{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "9e87c927",
   "metadata": {},
   "source": [
    "<img src=\"http://developer.download.nvidia.com/notebooks/dlsw-notebooks/tensorrt_torchtrt_efficientnet/nvidia_logo.png\" width=\"90px\">\n",
    "\n",
    "# PySpark PyTorch Inference\n",
    "\n",
    "### Image Classification\n",
    "\n",
    "In this notebook, we will train an MLP to perform image classification on FashionMNIST, and load it for distributed inference with Spark.\n",
    "\n",
    "Based on: https://pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html  \n",
    "\n",
    "We also demonstrate accelerated inference via Torch-TensorRT model compilation.   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "91d7ec98",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "import shutil\n",
    "from torch import nn\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets\n",
    "from torchvision.transforms import ToTensor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f71f801d",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.mkdir('models') if not os.path.exists('models') else None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d714f40d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'2.5.1+cu124'"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d0f6fb37",
   "metadata": {},
   "source": [
    "### Load Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1c942a46",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Download training data from open datasets.\n",
    "training_data = datasets.FashionMNIST(\n",
    "    root=\"datasets/data\",\n",
    "    train=True,\n",
    "    download=True,\n",
    "    transform=ToTensor(),\n",
    ")\n",
    "\n",
    "# Download test data from open datasets.\n",
    "test_data = datasets.FashionMNIST(\n",
    "    root=\"datasets/data\",\n",
    "    train=False,\n",
    "    download=True,\n",
    "    transform=ToTensor(),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4a89aa8e-ef62-4aac-8260-4b004f2c1b55",
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = [\n",
    "    \"T-shirt/top\",\n",
    "    \"Trouser\",\n",
    "    \"Pullover\",\n",
    "    \"Dress\",\n",
    "    \"Coat\",\n",
    "    \"Sandal\",\n",
    "    \"Shirt\",\n",
    "    \"Sneaker\",\n",
    "    \"Bag\",\n",
    "    \"Ankle boot\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "10a97111",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28]) torch.float32\n",
      "Shape of y: torch.Size([64]) torch.int64\n"
     ]
    }
   ],
   "source": [
    "batch_size = 64\n",
    "\n",
    "# Create data loaders.\n",
    "train_dataloader = DataLoader(training_data, batch_size=batch_size)\n",
    "test_dataloader = DataLoader(test_data, batch_size=batch_size)\n",
    "\n",
    "for X, y in test_dataloader:\n",
    "    print(f\"Shape of X [N, C, H, W]: {X.shape} {X.dtype}\")\n",
    "    print(f\"Shape of y: {y.shape} {y.dtype}\")\n",
    "    break"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca7af350",
   "metadata": {},
   "source": [
    "### Create model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "512d0bc7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda device\n",
      "NeuralNetwork(\n",
      "  (linear_relu_stack): Sequential(\n",
      "    (0): Linear(in_features=784, out_features=512, bias=True)\n",
      "    (1): ReLU()\n",
      "    (2): Linear(in_features=512, out_features=512, bias=True)\n",
      "    (3): ReLU()\n",
      "    (4): Linear(in_features=512, out_features=10, bias=True)\n",
      "  )\n",
      ")\n"
     ]
    }
   ],
   "source": [
    "# Get cpu or gpu device for training.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "print(f\"Using {device} device\")\n",
    "\n",
    "# Define model\n",
    "class NeuralNetwork(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(NeuralNetwork, self).__init__()\n",
    "        self.linear_relu_stack = nn.Sequential(\n",
    "            nn.Linear(28*28, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(512, 10)\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        logits = self.linear_relu_stack(x)\n",
    "        return logits\n",
    "\n",
    "model = NeuralNetwork().to(device)\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4573c1b7",
   "metadata": {},
   "source": [
    "### Train Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "4d4f5538",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fn = nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "92d9076a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def train(dataloader, model, loss_fn, optimizer):\n",
    "    size = len(dataloader.dataset)\n",
    "    model.train()\n",
    "    for batch, (X, y) in enumerate(dataloader):\n",
    "        X, y = X.to(device), y.to(device)\n",
    "        X = torch.flatten(X, start_dim=1, end_dim=-1)\n",
    "\n",
    "        # Zero gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Compute prediction error\n",
    "        pred = model(X)\n",
    "        loss = loss_fn(pred, y)\n",
    "\n",
    "        # Backpropagation\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        if batch % 100 == 0:\n",
    "            loss, current = loss.item(), (batch + 1) * len(X)\n",
    "            print(f\"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "11c5650d",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(dataloader, model, loss_fn):\n",
    "    size = len(dataloader.dataset)\n",
    "    num_batches = len(dataloader)\n",
    "    model.eval()\n",
    "    test_loss, correct = 0, 0\n",
    "    with torch.no_grad():\n",
    "        for X, y in dataloader:\n",
    "            X, y = X.to(device), y.to(device)\n",
    "            X = torch.flatten(X, start_dim=1, end_dim=-1)\n",
    "            pred = model(X)\n",
    "            test_loss += loss_fn(pred, y).item()\n",
    "            correct += (pred.argmax(1) == y).type(torch.float).sum().item()\n",
    "    test_loss /= num_batches\n",
    "    correct /= size\n",
    "    print(f\"Test Error: \\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \\n\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "854608e6",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1\n",
      "-------------------------------\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loss: 2.298206  [   64/60000]\n",
      "loss: 2.283203  [ 6464/60000]\n",
      "loss: 2.262282  [12864/60000]\n",
      "loss: 2.259791  [19264/60000]\n",
      "loss: 2.240928  [25664/60000]\n",
      "loss: 2.218922  [32064/60000]\n",
      "loss: 2.225280  [38464/60000]\n",
      "loss: 2.193091  [44864/60000]\n",
      "loss: 2.194699  [51264/60000]\n",
      "loss: 2.157922  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 38.6%, Avg loss: 2.149652 \n",
      "\n",
      "Epoch 2\n",
      "-------------------------------\n",
      "loss: 2.164765  [   64/60000]\n",
      "loss: 2.153999  [ 6464/60000]\n",
      "loss: 2.094229  [12864/60000]\n",
      "loss: 2.107332  [19264/60000]\n",
      "loss: 2.060189  [25664/60000]\n",
      "loss: 2.009164  [32064/60000]\n",
      "loss: 2.033063  [38464/60000]\n",
      "loss: 1.954014  [44864/60000]\n",
      "loss: 1.968186  [51264/60000]\n",
      "loss: 1.892358  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 54.1%, Avg loss: 1.883826 \n",
      "\n",
      "Epoch 3\n",
      "-------------------------------\n",
      "loss: 1.922989  [   64/60000]\n",
      "loss: 1.895849  [ 6464/60000]\n",
      "loss: 1.767882  [12864/60000]\n",
      "loss: 1.804950  [19264/60000]\n",
      "loss: 1.702711  [25664/60000]\n",
      "loss: 1.664090  [32064/60000]\n",
      "loss: 1.682484  [38464/60000]\n",
      "loss: 1.577310  [44864/60000]\n",
      "loss: 1.613093  [51264/60000]\n",
      "loss: 1.510797  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 59.5%, Avg loss: 1.517127 \n",
      "\n",
      "Epoch 4\n",
      "-------------------------------\n",
      "loss: 1.588409  [   64/60000]\n",
      "loss: 1.558777  [ 6464/60000]\n",
      "loss: 1.393466  [12864/60000]\n",
      "loss: 1.465835  [19264/60000]\n",
      "loss: 1.350062  [25664/60000]\n",
      "loss: 1.359687  [32064/60000]\n",
      "loss: 1.370576  [38464/60000]\n",
      "loss: 1.287119  [44864/60000]\n",
      "loss: 1.330430  [51264/60000]\n",
      "loss: 1.238912  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 62.4%, Avg loss: 1.254357 \n",
      "\n",
      "Epoch 5\n",
      "-------------------------------\n",
      "loss: 1.333722  [   64/60000]\n",
      "loss: 1.322049  [ 6464/60000]\n",
      "loss: 1.143545  [12864/60000]\n",
      "loss: 1.250494  [19264/60000]\n",
      "loss: 1.123120  [25664/60000]\n",
      "loss: 1.166146  [32064/60000]\n",
      "loss: 1.181268  [38464/60000]\n",
      "loss: 1.112326  [44864/60000]\n",
      "loss: 1.155791  [51264/60000]\n",
      "loss: 1.079376  [57664/60000]\n",
      "Test Error: \n",
      " Accuracy: 64.0%, Avg loss: 1.092456 \n",
      "\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "epochs = 5\n",
    "for t in range(epochs):\n",
    "    print(f\"Epoch {t+1}\\n-------------------------------\")\n",
    "    train(train_dataloader, model, loss_fn, optimizer)\n",
    "    test(test_dataloader, model, loss_fn)\n",
    "print(\"Done!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "85d97839",
   "metadata": {},
   "source": [
    "### Save Model State Dict\n",
    "This saves the serialized object to disk using pickle."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "5d5d24de",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved PyTorch Model State to models/model.pt\n"
     ]
    }
   ],
   "source": [
    "torch.save(model.state_dict(), \"models/model.pt\")\n",
    "print(\"Saved PyTorch Model State to models/model.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ac221ca7-e227-4c8c-8577-1eeda4a61fc7",
   "metadata": {},
   "source": [
    "### Save Model as TorchScript\n",
    "This saves an [intermediate representation of the compute graph](https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format), which does not require pickle (or even python). "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6d9b3a45-7618-43e4-8bd3-8bb317a484d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved TorchScript Model to models/ts_model.pt\n"
     ]
    }
   ],
   "source": [
    "scripted = torch.jit.script(model)\n",
    "scripted.save(\"models/ts_model.pt\")\n",
    "print(\"Saved TorchScript Model to models/ts_model.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "12ee8916-f437-4a2a-9bf4-14ff5376d305",
   "metadata": {},
   "source": [
    "### Load Model State"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "8fe3b5d1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_from_state = NeuralNetwork().to(device)\n",
    "model_from_state.load_state_dict(torch.load(\"models/model.pt\", weights_only=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "0c405bd0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n"
     ]
    }
   ],
   "source": [
    "model_from_state.eval()\n",
    "x, y = test_data[0][0], test_data[0][1]\n",
    "with torch.no_grad():\n",
    "    x = torch.flatten(x.to(device), start_dim=1, end_dim=-1)\n",
    "    pred = model_from_state(x)\n",
    "    predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
    "    print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "290c482a-1c5d-4bf2-bc3f-8a4e53d442b5",
   "metadata": {},
   "source": [
    "### Load Torchscript Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ef3c419e-d384-446c-b07b-1af93e07d6c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "ts_model = torch.jit.load(\"models/ts_model.pt\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c92d6cdb",
   "metadata": {},
   "outputs": [],
   "source": [
    "x, y = test_data[0][0], test_data[0][1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "038af830-a360-45eb-ab4e-b1adab0af164",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n"
     ]
    }
   ],
   "source": [
    "with torch.no_grad():\n",
    "    pred = ts_model(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\n",
    "    predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
    "    print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76980495",
   "metadata": {},
   "source": [
    "### Compile using the Torch JIT Compiler\n",
    "This leverages the [Torch-TensorRT inference compiler](https://pytorch.org/TensorRT/) for accelerated inference on GPUs using the `torch.compile` JIT interface under the hood. The compiler stack returns a [boxed-function](http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/) that triggers compilation on the first call.  \n",
    "\n",
    "Modules compiled in this fashion are [not serializable with pickle](https://github.com/pytorch/pytorch/issues/101107#issuecomment-1542688089), so we cannot send the compiled model directly to Spark. "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "414bc856",
   "metadata": {},
   "source": [
    "(You may see a warning about modelopt quantization. This is safe to ignore, as [implicit quantization](https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#intro-quantization) is deprecated in the latest TensorRT. See [this link](https://pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/vgg16_fp8_ptq.html) for a guide to explicit quantization.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "362b266b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch_tensorrt as trt\n",
    "import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "f0ac1362",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: set the filename for the TensorRT timing cache\n",
    "timestamp = time.time()\n",
    "timing_cache = f\"/tmp/timing_cache-{timestamp}.bin\"\n",
    "with open(timing_cache, \"wb\") as f:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "f3e3bdc4",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs_bs1 = torch.randn((10, 784), dtype=torch.float).to(\"cuda\")\n",
     "# This indicates dimension 0 of inputs_bs1 is dynamic, with a range of values of [1, 64]. \n",
    "torch._dynamo.mark_dynamic(inputs_bs1, 0, min=1, max=64)\n",
    "trt_model = trt.compile(\n",
    "    model,\n",
    "    ir=\"torch_compile\",\n",
    "    inputs=inputs_bs1,\n",
    "    enabled_precisions={torch.float},\n",
    "    timing_cache_path=timing_cache,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "66f61302",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:torch_tensorrt.dynamo._compiler:Node linear_default of op type call_function does not have metadata. This could sometimes lead to undefined behavior.\n",
      "WARNING:torch_tensorrt.dynamo._compiler:Some nodes do not have metadata (shape and dtype information). This could lead to problems sometimes if the graph has PyTorch and TensorRT segments.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n"
     ]
    }
   ],
   "source": [
    "stream = torch.cuda.Stream()\n",
    "with torch.no_grad(), torch.cuda.stream(stream):\n",
    "    pred = trt_model(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\n",
    "    predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
    "    print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9ec04be8",
   "metadata": {},
   "source": [
    "### Compile using the Torch-TensorRT AOT Compiler\n",
     "Alternatively, use the Torch-TensorRT Dynamo backend for Ahead-of-Time (AOT) compilation to eagerly optimize the model in an explicit compilation phase. We first export the model to produce a traced graph representing the Tensor computation, which yields an `ExportedProgram` object that can be [serialized and reloaded](https://pytorch.org/TensorRT/user_guide/saving_models.html). We can then compile this IR using the Torch-TensorRT AOT compiler for inference.   \n",
    "\n",
    "[Read the docs](https://pytorch.org/TensorRT/user_guide/torch_tensorrt_explained.html) for more information on JIT vs AOT compilation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "3e7e7689",
   "metadata": {},
   "outputs": [],
   "source": [
    "example_inputs = (torch.randn((10, 784), dtype=torch.float).to(\"cuda\"),)\n",
    "\n",
     "# Mark dim 0 (batch size) as dynamic\n",
    "batch = torch.export.Dim(\"batch\", min=1, max=64)\n",
    "# Produce traced graph in ExportedProgram format\n",
    "exp_program = torch.export.export(model_from_state, args=example_inputs, dynamic_shapes={\"x\": {0: batch}})\n",
    "# Compile the traced graph to produce an optimized module\n",
    "trt_gm = trt.dynamo.compile(exp_program, tuple(example_inputs), enabled_precisions={torch.float}, timing_cache_path=timing_cache)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "6fda0c0e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "<class 'torch.export.exported_program.ExportedProgram'>\n",
      "<class 'torch.fx.graph_module.GraphModule.__new__.<locals>.GraphModuleImpl'>\n"
     ]
    }
   ],
   "source": [
    "print(type(exp_program))\n",
    "print(type(trt_gm))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "5ed9e4c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted: \"Ankle boot\", Actual: \"Ankle boot\"\n"
     ]
    }
   ],
   "source": [
    "stream = torch.cuda.Stream()\n",
    "with torch.no_grad(), torch.cuda.stream(stream):\n",
    "    trt_gm(torch.flatten(x.to(device), start_dim=1, end_dim=-1))\n",
    "    predicted, actual = classes[pred[0].argmax(0)], classes[y]\n",
    "    print(f'Predicted: \"{predicted}\", Actual: \"{actual}\"')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38697a06",
   "metadata": {},
   "source": [
    "We can run the optimized module with a few different batch sizes (without recompilation!):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27871156",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Output shapes:\n",
      "torch.Size([10, 10])\n",
      "torch.Size([1, 10])\n",
      "torch.Size([50, 10])\n"
     ]
    }
   ],
   "source": [
    "inputs = (torch.randn((10, 784), dtype=torch.float).cuda(),)\n",
    "inputs_bs1 = (torch.randn((1, 784), dtype=torch.float).cuda(),)\n",
    "inputs_bs50 = (torch.randn((50, 784), dtype=torch.float).cuda(),)\n",
    "\n",
    "stream = torch.cuda.Stream()\n",
    "with torch.no_grad(), torch.cuda.stream(stream):\n",
    "    print(\"Output shapes:\")\n",
    "    print(trt_gm(*inputs).shape)\n",
    "    print(trt_gm(*inputs_bs1).shape)\n",
    "    print(trt_gm(*inputs_bs50).shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ab974244",
   "metadata": {},
   "source": [
    "We can serialize the ExportedProgram (a traced graph representing the model's forward function) using `torch.export.save` to be recompiled at a later date."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d87e4b20",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Saved ExportedProgram to models/trt_model.ep\n"
     ]
    }
   ],
   "source": [
    "torch.export.save(exp_program, \"models/trt_model.ep\")\n",
    "print(\"Saved ExportedProgram to models/trt_model.ep\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ad918393",
   "metadata": {},
   "source": [
    "## PySpark"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "42c5feba",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pyspark.sql.functions import col, struct, pandas_udf, array\n",
    "from pyspark.ml.functions import predict_batch_udf\n",
    "from pyspark.sql.types import *\n",
    "from pyspark.sql import SparkSession\n",
    "from pyspark import SparkConf"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "ef97321d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import json\n",
    "import os"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ece094d6",
   "metadata": {},
   "source": [
    "Check the cluster environment to handle any platform-specific Spark configurations."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "10eb841f",
   "metadata": {},
   "outputs": [],
   "source": [
    "on_databricks = os.environ.get(\"DATABRICKS_RUNTIME_VERSION\", False)\n",
    "on_dataproc = os.environ.get(\"DATAPROC_IMAGE_VERSION\", False)\n",
    "on_standalone = not (on_databricks or on_dataproc)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "425e94ac",
   "metadata": {},
   "source": [
    "#### Create Spark Session\n",
    "\n",
    "For local standalone clusters, we'll connect to the cluster and create the Spark Session.  \n",
    "For CSP environments, Spark will either be preconfigured (Databricks) or we'll need to create the Spark Session (Dataproc)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "60ba6e74",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:50:47 WARN Utils: Your hostname, cb4ae00-lcedt resolves to a loopback address: 127.0.1.1; using 10.110.47.100 instead (on interface eno1)\n",
      "25/02/04 13:50:47 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "25/02/04 13:50:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n"
     ]
    }
   ],
   "source": [
    "conf = SparkConf()\n",
    "\n",
    "if 'spark' not in globals():\n",
    "    if on_standalone:\n",
    "        import socket\n",
    "        conda_env = os.environ.get(\"CONDA_PREFIX\")\n",
    "        hostname = socket.gethostname()\n",
    "        conf.setMaster(f\"spark://{hostname}:7077\")\n",
    "        conf.set(\"spark.pyspark.python\", f\"{conda_env}/bin/python\")\n",
    "        conf.set(\"spark.pyspark.driver.python\", f\"{conda_env}/bin/python\")\n",
    "\n",
    "    conf.set(\"spark.executor.cores\", \"8\")\n",
    "    conf.set(\"spark.task.resource.gpu.amount\", \"0.125\")\n",
    "    conf.set(\"spark.executor.resource.gpu.amount\", \"1\")\n",
    "    conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"true\")\n",
    "    conf.set(\"spark.python.worker.reuse\", \"true\")\n",
    "    \n",
    "spark = SparkSession.builder.appName(\"spark-dl-examples\").config(conf=conf).getOrCreate()\n",
    "sc = spark.sparkContext"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cd11476",
   "metadata": {},
   "source": [
    "#### Create Spark DataFrame from Pandas DataFrame"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "f063cbe7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10000, 28, 28), dtype('uint8'))"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = test_data.data.numpy()\n",
    "data.shape, data.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c828393",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((10000, 784), dtype('float64'))"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = data.reshape(10000, 784) / 255.0\n",
    "data.shape, data.dtype"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "7760bdbe",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>774</th>\n",
       "      <th>775</th>\n",
       "      <th>776</th>\n",
       "      <th>777</th>\n",
       "      <th>778</th>\n",
       "      <th>779</th>\n",
       "      <th>780</th>\n",
       "      <th>781</th>\n",
       "      <th>782</th>\n",
       "      <th>783</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.007843</td>\n",
       "      <td>0.011765</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.011765</td>\n",
       "      <td>0.682353</td>\n",
       "      <td>0.741176</td>\n",
       "      <td>0.262745</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.003922</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.643137</td>\n",
       "      <td>0.227451</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.082353</td>\n",
       "      <td>...</td>\n",
       "      <td>0.003922</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.007843</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.003922</td>\n",
       "      <td>0.003922</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.278431</td>\n",
       "      <td>0.047059</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.121569</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.105882</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>...</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "      <td>0.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10000 rows × 784 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "      0    1    2         3    4         5         6    7         8    \\\n",
       "0     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "1     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "2     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.003922   \n",
       "3     0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "4     0.0  0.0  0.0  0.007843  0.0  0.003922  0.003922  0.0  0.000000   \n",
       "...   ...  ...  ...       ...  ...       ...       ...  ...       ...   \n",
       "9995  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "9996  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "9997  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "9998  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "9999  0.0  0.0  0.0  0.000000  0.0  0.000000  0.000000  0.0  0.000000   \n",
       "\n",
       "           9    ...       774       775  776       777       778       779  \\\n",
       "0     0.000000  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \n",
       "1     0.000000  ...  0.007843  0.011765  0.0  0.011765  0.682353  0.741176   \n",
       "2     0.000000  ...  0.643137  0.227451  0.0  0.000000  0.000000  0.000000   \n",
       "3     0.082353  ...  0.003922  0.000000  0.0  0.000000  0.000000  0.000000   \n",
       "4     0.000000  ...  0.278431  0.047059  0.0  0.000000  0.000000  0.000000   \n",
       "...        ...  ...       ...       ...  ...       ...       ...       ...   \n",
       "9995  0.000000  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \n",
       "9996  0.121569  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \n",
       "9997  0.000000  ...  0.105882  0.000000  0.0  0.000000  0.000000  0.000000   \n",
       "9998  0.000000  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \n",
       "9999  0.000000  ...  0.000000  0.000000  0.0  0.000000  0.000000  0.000000   \n",
       "\n",
       "           780  781  782  783  \n",
       "0     0.000000  0.0  0.0  0.0  \n",
       "1     0.262745  0.0  0.0  0.0  \n",
       "2     0.000000  0.0  0.0  0.0  \n",
       "3     0.000000  0.0  0.0  0.0  \n",
       "4     0.000000  0.0  0.0  0.0  \n",
       "...        ...  ...  ...  ...  \n",
       "9995  0.000000  0.0  0.0  0.0  \n",
       "9996  0.000000  0.0  0.0  0.0  \n",
       "9997  0.000000  0.0  0.0  0.0  \n",
       "9998  0.000000  0.0  0.0  0.0  \n",
       "9999  0.000000  0.0  0.0  0.0  \n",
       "\n",
       "[10000 rows x 784 columns]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# build a pandas DataFrame with one float column per pixel (10000 rows x 784 columns)\n",
    "pdf784 = pd.DataFrame(data)\n",
    "pdf784"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "f7d2bc0d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>data</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.00784313725490196, 0.0, 0.00...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9995</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9996</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9997</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9998</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>9999</th>\n",
       "      <td>[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>10000 rows × 1 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                   data\n",
       "0     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "1     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "2     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.003...\n",
       "3     [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "4     [0.0, 0.0, 0.0, 0.00784313725490196, 0.0, 0.00...\n",
       "...                                                 ...\n",
       "9995  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "9996  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "9997  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "9998  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "9999  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...\n",
       "\n",
       "[10000 rows x 1 columns]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 1 column of array<float>\n",
    "# each row holds the full 784-pixel vector as a single list\n",
    "pdf1 = pd.DataFrame()\n",
    "pdf1['data'] = pdf784.values.tolist()\n",
    "pdf1"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07b2a70b",
   "metadata": {},
   "source": [
    "Create dataframes with a single column of 784 floats and 784 separate columns."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "4863d5ff",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 185 ms, sys: 28.9 ms, total: 214 ms\n",
      "Wall time: 1.5 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "StructType([StructField('data', ArrayType(FloatType(), True), True)])"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "# force FloatType since Spark defaults to DoubleType\n",
    "# single array<float> column; repartition(8) spreads rows across 8 tasks\n",
    "schema = StructType([StructField(\"data\",ArrayType(FloatType()), True)])\n",
    "df = spark.createDataFrame(pdf1, schema).repartition(8)\n",
    "df.schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "831f4a01-3a49-4114-b9a0-2ae54526d72d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 66.9 ms, sys: 11.2 ms, total: 78.1 ms\n",
      "Wall time: 875 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "StructType([StructField('data', ArrayType(FloatType(), True), True)])"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%%time\n",
    "# force FloatType since Spark defaults to DoubleType\n",
    "schema = StructType([StructField(str(x), FloatType(), True) for x in range(784)])\n",
    "df784 = spark.createDataFrame(pdf784, schema).repartition(8)\n",
    "# show the schema of the dataframe created in this cell (was `df.schema`, a typo)\n",
    "df784.schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "e8ebae46",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:50:51 WARN TaskSetManager: Stage 0 contains a task of very large size (4030 KiB). The maximum recommended task size is 1000 KiB.\n",
      "[Stage 0:=======>                                                   (1 + 7) / 8]\r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.09 ms, sys: 1.6 ms, total: 3.69 ms\n",
      "Wall time: 1.71 s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# write the single-column dataset as parquet; on Databricks, target DBFS\n",
    "data_path_1 = \"spark-dl-datasets/fashion_mnist_1\"\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path_1 = \"dbfs:/FileStore/\" + data_path_1\n",
    "\n",
    "df.write.mode(\"overwrite\").parquet(data_path_1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "922314ce-2996-4666-9fc9-bcd98d16bb56",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:50:53 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "25/02/04 13:50:53 WARN TaskSetManager: Stage 3 contains a task of very large size (7847 KiB). The maximum recommended task size is 1000 KiB.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.94 ms, sys: 61 μs, total: 3 ms\n",
      "Wall time: 943 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# write the 784-column dataset as parquet; on Databricks, target DBFS\n",
    "data_path_784 = \"spark-dl-datasets/fashion_mnist_784\"\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-datasets\")\n",
    "    data_path_784 = \"dbfs:/FileStore/\" + data_path_784\n",
    "\n",
    "df784.write.mode(\"overwrite\").parquet(data_path_784)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fce89cb0",
   "metadata": {},
   "source": [
    "## Inference using Spark DL API\n",
    "\n",
    "Distributed inference using the PySpark [predict_batch_udf](https://spark.apache.org/docs/latest/api/python/reference/api/pyspark.ml.functions.predict_batch_udf.html#pyspark.ml.functions.predict_batch_udf):\n",
    "\n",
    "- predict_batch_fn uses PyTorch APIs to load the model and return a predict function which operates on numpy arrays \n",
    "- predict_batch_udf will convert the Spark DataFrame columns into numpy input batches for the predict function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59395856-a588-43c6-93c8-c83100716ac1",
   "metadata": {
    "tags": []
   },
   "source": [
    "### 1 column of 784 floats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "79b151d9-d112-43b6-a479-887e2fd0e2b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# reload the saved dataset; it has a single 'data' column\n",
    "df = spark.read.parquet(data_path_1)\n",
    "len(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "3e6a4dbb",
   "metadata": {},
   "outputs": [],
   "source": [
    "# A resource warning may occur due to unclosed file descriptors used by TensorRT across multiple PySpark daemon processes.\n",
    "# These can be safely ignored as the resources will be cleaned up when the worker processes terminate.\n",
    "\n",
    "# suppress only ResourceWarning; other warning categories still surface\n",
    "import warnings\n",
    "warnings.simplefilter(\"ignore\", ResourceWarning)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "16e523c2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# get absolute path to model\n",
    "model_path = \"{}/models/trt_model.ep\".format(os.getcwd())\n",
    "\n",
    "# For cloud environments, copy the model to the distributed file system.\n",
    "if on_databricks:\n",
    "    dbutils.fs.mkdirs(\"/FileStore/spark-dl-models\")\n",
    "    # keep the .ep extension so the artifact is recognizable as a torch.export program\n",
    "    dbfs_model_path = \"/dbfs/FileStore/spark-dl-models/trt_model.ep\"\n",
    "    shutil.copy(model_path, dbfs_model_path)\n",
    "    model_path = dbfs_model_path\n",
    "elif on_dataproc:\n",
    "    # GCS is mounted at /mnt/gcs by the init script\n",
    "    models_dir = \"/mnt/gcs/spark-dl/models\"\n",
    "    # makedirs with exist_ok=True creates missing parents and avoids a race on re-runs\n",
    "    os.makedirs(models_dir, exist_ok=True)\n",
    "    gcs_model_path = models_dir + \"/trt_model.ep\"\n",
    "    shutil.copy(model_path, gcs_model_path)\n",
    "    model_path = gcs_model_path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "73dc73cb-25e3-4798-a019-e1abd684eaa1",
   "metadata": {},
   "outputs": [],
   "source": [
    "def predict_batch_fn():\n",
    "    \"\"\"Load the torch.export program, compile it with Torch-TensorRT, and return a predict closure for predict_batch_udf.\"\"\"\n",
    "    import torch\n",
    "    import torch_tensorrt as trt\n",
    "    \n",
    "    device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "    if device != \"cuda\":\n",
    "        raise ValueError(\"This function uses the TensorRT model which requires a GPU device\")\n",
    "\n",
    "    # example input matches the UDF batch_size (50) and the 784-float row shape\n",
    "    example_inputs = (torch.randn((50, 784), dtype=torch.float).to(\"cuda\"),)\n",
    "    exp_program = torch.export.load(model_path)\n",
    "    trt_gm = trt.dynamo.compile(exp_program,\n",
    "                                tuple(example_inputs),\n",
    "                                enabled_precisions={torch.float},\n",
    "                                workspace_size=1<<30)\n",
    "\n",
    "    def predict(inputs: np.ndarray):\n",
    "        # run on a dedicated CUDA stream with gradients disabled\n",
    "        stream = torch.cuda.Stream()\n",
    "        with torch.no_grad(), torch.cuda.stream(stream):\n",
    "            # use array to combine columns into tensors\n",
    "            torch_inputs = torch.from_numpy(inputs).to(device)\n",
    "            outputs = trt_gm(torch_inputs)\n",
    "            return outputs.detach().cpu().numpy()\n",
    "\n",
    "    return predict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "df68cca1-2d47-4e88-8aad-9899402aee97",
   "metadata": {},
   "outputs": [],
   "source": [
    "# rows arrive as flat [784]-float vectors; batch_size matches the TRT example input (50)\n",
    "mnist = predict_batch_udf(predict_batch_fn,\n",
    "                          input_tensor_shapes=[[784]],\n",
    "                          return_type=ArrayType(FloatType()),\n",
    "                          batch_size=50)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "63555b3b-3673-4712-97aa-fd728c6c4979",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 167 ms, sys: 76.2 ms, total: 243 ms\n",
      "Wall time: 10.9 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# first pass compiles and caches model/fn\n",
    "# columns are passed wrapped in a struct; here df has the single 'data' column\n",
    "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "5dbf058a-70d6-4199-af9d-13843d078950",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 234 ms, sys: 64.1 ms, total: 298 ms\n",
      "Wall time: 685 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# columns passed positionally; the cached model/fn makes this pass much faster\n",
    "preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "id": "3f5ed801-6ca5-43a0-bf9c-2535a0dfe2e8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 403 ms, sys: 60.1 ms, total: 463 ms\n",
      "Wall time: 809 ms\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# equivalent invocation passing explicit Column objects\n",
    "preds = df.withColumn(\"preds\", mnist(*[col(c) for c in df.columns])).collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6dbec03-9b64-46c4-a748-f889be571384",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Check predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "id": "f1f1e5fd-5866-4b78-b9d3-709e6b383a0c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# take the first row: its prediction vector and its original input pixels\n",
    "predictions = preds[0].preds\n",
    "img = preds[0].data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "76b76502-adb7-45ec-a365-2e61cdd576fc",
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): plotting imports appear mid-notebook; consider moving them to the top imports cell\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "id": "c163953a-1504-444f-b39f-86b61d34e440",
   "metadata": {},
   "outputs": [],
   "source": [
    "# restore the flat 784-vector to its 28x28 image shape for display\n",
    "img = np.array(img).reshape(28,28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "bc0fad05-50ab-4ae5-b9fd-e50133c4c92a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJ5JREFUeJzt3Xt0lfWd7/HPTkg2t2THEHKTQAMotHLplErKqBRLDpDOuEA5HW9zDrg6MNLgqlKrJz1Watuz0uIa66lDca2zWqir4oVzREbHYgUljAp0QBjGXlLAKGEgoWCTDQlJdrJ/5w/GzERB+P5M8kvC+7XWXovs/Xx4fnnyJJ882TvfRJxzTgAA9LKU0AsAAFyaKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQQwKvYAPSyaTOnr0qDIyMhSJREIvBwBg5JzTqVOnVFhYqJSU81/n9LkCOnr0qIqKikIvAwDwCdXW1mrUqFHnfbzPFVBGRoYk6Vp9WYOUFng16Ha9dVXLhCkgmHYl9Lpe6vx6fj49VkCrV6/Www8/rLq6Ok2dOlWPPfaYpk+ffsHcBz92G6Q0DYpQQANOr/1YlQICgvn3T78LPY3SIy9CeOaZZ7RixQqtXLlSb731lqZOnaq5c+fq+PHjPbE7AEA/1CMF9Mgjj2jJkiW644479JnPfEaPP/64hg4dqp/97Gc9sTsAQD/U7QXU1tamPXv2qLS09D92kpKi0tJS7dix4yPbt7a2Kh6Pd7kBAAa+bi+gEydOqKOjQ3l5eV3uz8vLU11d3Ue2r6ysVCwW67zxCjgAuDQE/0XUiooKNTY2dt5qa2tDLwkA0Au6/VVwOTk5Sk1NVX19fZf76+vrlZ+f/5Hto9GootFody8DANDHdfsVUHp6uqZNm6atW7d23pdMJrV161bNmDGju3cHAOineuT3gFasWKFFixbp85//vKZPn65HH31UTU1NuuOOO3pidwCAfqhHCujmm2/WH//4Rz344IOqq6vTZz/7WW3evPkjL0wAAFy6Is71rZkl8XhcsVhMszSfSQgYsFLHF5szx/7O/lxp7vzfmzPAJ9XuEtqmTWpsbFRmZuZ5twv+KjgAwKWJAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEH0yDRsoL+q+YH9b1bdP3+jOVPVcP4BjeczPu2MOZO/356RpF9smG3OFH3vTa994dLFFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCYBp2HxYZZP/wuI4O+46cs2c8RdLSzRmXaDNnBhWPMWckacttD5szX/zV3ebMlX+z25ypNyek3Tdf75GSHvzuU+bM2u/5HXOzSMSe6cVzHBePKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLiXN+a0hePxxWLxTRL8zUokhZ6Od3HY4BiJDXVnPEaRuqrb506Xfxh7TSv3NiiP5ozg0oPe+2rLzvzcrE5c8Pl+82ZLZMyzBn0fe0uoW3apMbGRmVmZp53O66AAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACCIQaEXcMnwGNzp2tvt+/EYetqXh4pKUsrUT5szW770v732VfrLFebMlfIYRppiHzQbSbF/bL3OIUlpP8w2Z25e9y/mzIbF3zRnLlu3w5xB38QVEAAgCAoIABBEtxfQd77zHUUikS63iRMndvduAAD9XI88B3TVVVdpy5Yt/7GTQTzVBADoqkeaYdCgQcrPz++J/xoAMED0yHNABw4cUGFhocaOHavbb79dhw+f/1VCra2tisfjXW4AgIGv2wuopKRE
69at0+bNm7VmzRrV1NTouuuu06lTp865fWVlpWKxWOetqKiou5cEAOiDur2AysrK9JWvfEVTpkzR3Llz9dJLL6mhoUHPPvvsObevqKhQY2Nj5622tra7lwQA6IN6/NUBWVlZuvLKK3Xw4MFzPh6NRhWNRnt6GQCAPqbHfw/o9OnTOnTokAoKCnp6VwCAfqTbC+jee+9VVVWV3n33Xb355pu68cYblZqaqltvvbW7dwUA6Me6/UdwR44c0a233qqTJ09q5MiRuvbaa7Vz506NHDmyu3cFAOjHur2Ann766e7+L2HhM1jUYzCmJCnZYY7Eb/2COTNm+R/MmcdPXmfOSFLhq700ncolzZFIdKh9N57DSOu+YH9e9p1EpjnzzEMPmzMr//bL5kz9DH69oy9iFhwAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABNHjf5Cu10Qi9ozP4M7e5DMk1GPIpc9QUV9Xf2OPOVPdmGfONLenmzOSNPzZnV45q0iq5wDYXpJ2yp55s+kKc+bRP33KnLlr1BZz5oHblpgzkpS53uN86K2vRT778d1XD+EKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEEMnGnYfZ3H5Fqficku0XuTrYdtH2nOtDv7mOXUFPuE73c3jTVnJKlAdV45K9fh8XFqS3T/Qs4j77E3zZlvVVSbM9cevcqcWXlgvjlz8//cbM5I0ssvFJkzyVMeo8R9+E617kN/OYArIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYuAMI/UYlhdJS/fbVaLNI2Rfn9d+PBy978+9ct/MfdacefLfvmDOJGUfnljwiH2YZq/qw+eDr181p5kz/33MTnNm1d455kxqkd8wzYxf2r9GNF7rtaveE/G47nA9M+SYKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLinMdUxB4Uj8cVi8U0S/M1KGIfbjiQHF9uHxLammXfz6+WrrKHJK1r+Lw5Myr9fXPme7+8yZyJ/cE+wFSSrv+bXebMxt981pwZ9G9RcyalzeN9ivh9evvsq3VE0pwZOeGEOROLtpgzZ9r9vpasHP8P5syyDUvNmeL/scOc6cvaXULbtEmNjY3KzMw873ZcAQEAgqCAAABBmAto+/btuuGGG1RYWKhIJKLnn3++y+POOT344IMqKCjQkCFDVFpaqgMHDnTXegEAA4S5gJqamjR16lStXr36nI+vWrVKP/7xj/X4449r165dGjZsmObOnauWFvvPbQEAA5f5L6KWlZWprKzsnI855/Too4/qgQce0Pz58yVJTzzxhPLy8vT888/rlltu+WSrBQAMGN36HFBNTY3q6upUWlraeV8sFlNJSYl27Dj3qzxaW1sVj8e73AAAA1+3FlBdXZ0kKS8vr8v9eXl5nY99WGVlpWKxWOetqKioO5cEAOijgr8KrqKiQo2NjZ232tra0EsCAPSCbi2g/Px8SVJ9fX2X++vr6zsf+7BoNKrMzMwuNwDAwNetBVRcXKz8/Hxt3bq18754PK5du3ZpxowZ3bkrAEA/Z34V3OnTp3Xw4MHOt2tqarRv3z5lZ2dr9OjRuvvuu/X9739fV1xxhYqLi/Xtb39bhYWFWrBgQXeuGwDQz5kLaPfu3br++us7316xYoUkadGiRVq3bp3uu+8+NTU1aenSpWpoaNC1116rzZs3a/Dgwd23agBAv3dJDyN9Z5XfjwX/7safmzMr/vmvzJn09HZz5gdTnzNnNjdMMWckKUX2U+d3jXkX3uhD3tt7uTnTcVnCnJGkwbXp5kzmO/bjkNJuz3Sk2weEJs3fYp7l
Uu2ZZJp9fZGk/Ticntlszny26Ig5I0lHT8fMmWvy3jFn3nrf/urfI+9nmTOSNPor/+qVs2AYKQCgT6OAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAIz1m5A8Nvb/97r9zCg39hzuT8o/3PUZxeeMqc+dnR68yZxja/P5VxR9Eb5syJtmHmTM3gpDmjDvtkZklqu8y+r8RX/mTOjIo1mjMjo6fNmWiqfaK6JGUNsk+cTniM0G7qiJozMzOr7ftJ2vcjSW8OGm/OpMp+Dg0b1GbOvDR9jTkjSbfefq85E3typ9e+LoQrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYsAMI225Ybo5kxbZ57evb+WbMzn/6z1z5uXxz5kzD5+wH4ehKfZBiJK08vUF5kxK3H7KuSyPgZoe80vP7ithzsQPXGbOHGwYYc7U2meeKrXV2UOeOqL2AbDOY2bs6+mfM2duX/yKfUeSrsv6gznzucGHzZmX064yZ/7in+80ZyTpbyp+Zc68/GSm174uhCsgAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAhiwAwjrf98770rWT+sNWf+MudfzJmfNtgHFP5V1j+bM9+t/UtzRpIu/2WqOZMY6jF9UmnmRCTpN4TTpdjX15Fu308yzb4+n7W1Zvkcb0kesYjHzFif/Qz/N/uk2cf/6Xr7jiT9Yf4ac+Y3bfZ3alFsvznzj5mTzRlJ+uvYv5ozv/qzpabtIx2t0r9suuB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhp2une29e9l282Zzb8abo5k5seN2e+eei/mjPNf3+5OSNJZ3Lt37/4DMdMbTNHlEz1G8KZ0ksDNX0Gdzr77FevtUlS+xC/nJXPcTg12n7eFVTZ9yNJX5s205wpHNxgzlSfzjNnFl6+15yRpBEp9g/umcuHmbZvT6RKFzF/mSsgAEAQFBAAIAhzAW3fvl033HCDCgsLFYlE9Pzzz3d5fPHixYpEIl1u8+bN6671AgAGCHMBNTU1aerUqVq9evV5t5k3b56OHTvWeXvqqac+0SIBAAOP+UUIZWVlKisr+9htotGo8vPzvRcFABj4euQ5oG3btik3N1cTJkzQsmXLdPLkyfNu29raqng83uUGABj4ur2A5s2bpyeeeEJbt27VD3/4Q1VVVamsrEwdHR3n3L6yslKxWKzzVlRU1N1LAgD0Qd3+e0C33HJL578nT56sKVOmaNy4cdq2bZtmz579ke0rKiq0YsWKzrfj8TglBACXgB5/GfbYsWOVk5OjgwcPnvPxaDSqzMzMLjcAwMDX4wV05MgRnTx5UgUFBT29KwBAP2L+Edzp06e7XM3U1NRo3759ys7OVnZ2th566CEtXLhQ+fn5OnTokO677z6NHz9ec+fO7daFAwD6N3MB7d69W9dff33n2x88f7No0SKtWbNG+/fv189//nM1NDSosLBQc+bM0fe+9z1Fo9HuWzUAoN8zF9CsWbPknDvv4y+//PInWpCvQc29t68DbfbfcXo43z448P+dtj8f1viz/2LOJGN+EysTw+y5yLlfDPmxOtLtGd8hnB9zap9/V700WNRrQKjncRh0xp5J8Rga63M++By7jqjfgfj1+qnmzIYVD5sz/5Q+zpy5esi75owkxZNJc2ZY9QnT9u0drRe1HbPgAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEES3/0nuUIacsE949TUu7bg5c6TdPl74vpfuNGcyRti/p0gMM0ckSelxe8Z5fMvT
MdiekcdUa0lqH+qxK4/pzD7r85m67Ssx3L7ApMdXk9Q2+5TqlIsbtNxF/FN+07AzDtuPw7cOzzdn/u+4LebMtjMeJ6ukwtRT5kzHgXds27vERW3HFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABDFghpFmHjptziRch9e+xqa1mDMzdywzZ/LfsA9C/NNEc8R7yGVLjj3TEbW/T9H3PQZJen5r5TMstX2o/X1qz7CfeynDL27A43+WbPWZlCpFj6aZM2mn7R8nr2PnMZw2tdVvGOmZXHvunfVXmDO/u/8fzBlpuEdGuixliDmTOmG8aXvX0SoduPB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhpSnObOZMW8RvU+K9tmebM8F/ZBwf+aaLHAMWkPdIxxD4QUpLasuw7i56wH/P0U/b1pS84bs5I0snGYeZMR5v90yj9cNScyd5u/34x4veh1Zls+7nXXOgxWNRjGKl85or6zSJVIsO+viEp9o/T/J13mjO/nPETc0aS2mU/9yIJ28TiSPLitucKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCGDDDSNXeYY40Js947epvd37NnEkbZZ+G2Jpjf59SWu37iST8JjWm5zebM2N/FDdn2t87Ys7UR0vMGUkav/WkPXS01p7JHWGOvHdTrjnTPMY2RPIDkaH24b7ujMdw36TH+dpuz3Sk+k1ldYPsueZR9kzGTvuw4rrpQ80ZSRqXZr/uaH/nXdv2LnFR23EFBAAIggICAARhKqDKykpdffXVysjIUG5urhYsWKDq6uou27S0tKi8vFwjRozQ8OHDtXDhQtXX13frogEA/Z+pgKqqqlReXq6dO3fqlVdeUSKR0Jw5c9TU1NS5zT333KMXXnhBGzZsUFVVlY4ePaqbbrqp2xcOAOjfTC9C2Lx5c5e3161bp9zcXO3Zs0czZ85UY2OjfvrTn2r9+vX60pe+JElau3atPv3pT2vnzp36whe+0H0rBwD0a5/oOaDGxkZJUnZ2tiRpz549SiQSKi0t7dxm4sSJGj16tHbs2HHO/6O1tVXxeLzLDQAw8HkXUDKZ1N13361rrrlGkyZNkiTV1dUpPT1dWVlZXbbNy8tTXV3dOf+fyspKxWKxzltRUZHvkgAA/Yh3AZWXl+vtt9/W008//YkWUFFRocbGxs5bba3H71QAAPodr19EXb58uV588UVt375do0aN6rw/Pz9fbW1tamho6HIVVF9fr/z8/HP+X9FoVNFo1GcZAIB+zHQF5JzT8uXLtXHjRr366qsqLi7u8vi0adOUlpamrVu3dt5XXV2tw4cPa8aMGd2zYgDAgGC6AiovL9f69eu1adMmZWRkdD6vE4vFNGTIEMViMX31q1/VihUrlJ2drczMTN11112aMWMGr4ADAHRhKqA1a9ZIkmbNmtXl/rVr12rx4sWSpB/96EdKSUnRwoUL1draqrlz5+onP/lJtywWADBwmArIuQsP2Rs8eLBWr16t1atXey/KS6r99RQvNo268Ebn4JIeGY95n6nNHkMDh3sMMI34vRYl2Wp/CrH5ypHmTHrNe+bM5c8eMmckqf4vxpozJ6/JMGeyRpw2Z1pP24fnRt5PN2ckSQ1p9n35zLT1OfV8Mn6zSBVJ2HfmhtsHwDYX2vfz11v+1pyRpJq//D/mTOqIbNP2LtkmvX/h7ZgFBwAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCC8/iJqX9TxuwPmzLCUVq99PXHNT82Zv04sNWcizanmTGosYc4kO/wmJrsW++kTX95ozqTcdaU509xqn+Ys
Sc6dsofeH2KONL6bZc5E7IPOpTTPMdD2U89r4rRL9Qh5TKOP+Iyjl+QG2w96aoP986JjmP2dSq/vvS/fLX9WfOGN/pP29hbptQtvxxUQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAAQxYIaR+shKafbKpXpMXfzRdU+bM4WD/mTO/OLkn5szL74xzZyRJHXYBzy+fyTLnElttn+f5HrxW6uIx0BNl2YfPun8ZsZ6ibR7DO/0mffpMyvVYz8uxXMoq8c5nox67ssoJeE3YLU52WbOtGXZqqI9cXHbcwUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFc0sNIt5y6yis3Zehhc+Zyj8GiE9LazZlPDT5pzvy3Wf9kzkjS+t9cbQ8dHmKOpNgPgxIx+7BPSYp4DJ+MJO2ZlDN+gyStXO/sRpIU8ZjB6VLsC/QZNOuztrP78nmnfM4h+24GtdgzklTT3mHODD6RMG3f3n5x23MFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBXNrDSI9O8MrFRjebM+PS/mjO/CJ+pTmTFrEPGizL2G/OSNKUz9eaMyNKTpszean2zMb4n5kzkvTSUfuA2vak/fu45rY0cybpsZ9omm2I5AeGeAzC9TkOg1LsUzh95oqebol6pKTYEPvEz8LhjeZMW0eqOZP0mcoq6Y0z48yZYzMGm7bvaJV0ETOOuQICAARBAQEAgjAVUGVlpa6++mplZGQoNzdXCxYsUHV1dZdtZs2apUgk0uV25513duuiAQD9n6mAqqqqVF5erp07d+qVV15RIpHQnDlz1NTU1GW7JUuW6NixY523VatWdeuiAQD9n+lFCJs3b+7y9rp165Sbm6s9e/Zo5syZnfcPHTpU+fn53bNCAMCA9ImeA2psPPtqj+zs7C73P/nkk8rJydGkSZNUUVGh5ubzv2qstbVV8Xi8yw0AMPB5vww7mUzq7rvv1jXXXKNJkyZ13n/bbbdpzJgxKiws1P79+3X//ferurpazz333Dn/n8rKSj300EO+ywAA9FPeBVReXq63335br7/+epf7ly5d2vnvyZMnq6CgQLNnz9ahQ4c0btxHX39eUVGhFStWdL4dj8dVVFTkuywAQD/hVUDLly/Xiy++qO3bt2vUqFEfu21JSYkk6eDBg+csoGg0qmjU75fEAAD9l6mAnHO66667tHHjRm3btk3FxcUXzOzbt0+SVFBQ4LVAAMDAZCqg8vJyrV+/Xps2bVJGRobq6uokSbFYTEOGDNGhQ4e0fv16ffnLX9aIESO0f/9+3XPPPZo5c6amTJnSI+8AAKB/MhXQmjVrJJ39ZdP/bO3atVq8eLHS09O1ZcsWPfroo2pqalJRUZEWLlyoBx54oNsWDAAYGMw/gvs4RUVFqqqq+kQLAgBcGi7padhDPCcFL8v6jTlzMBExZ8qz7NOm/dgn8Z5l/52t5mSbOTM0Zag58+mc6gtvdA7LLttrzlyWal+fjzda7JOjG5J+a8tPtX9sB3tMYu+Q/fOixdnP15EpreaMJBWnDTdn9rTaz/HxafZjt6Mly5yRpPV/LDFnRlW+adq+3SV04CK2YxgpACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARxaQ8j/brfX2Itm/B1cybaYB98mkzrne8POqJ++zlyvT0XybMPhcx4c4g5U/jLo+aMJLlU+/vUcdkwcyb1VIs5o2PHzRGXaLfvR1JkqH2IaWS4x+DTC0zYP6d2++BOX81X2f+Qps/n7fA9h82Z9mN15sxZ9kGzPYUrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAA
gCAoIABAEBQQACAICggAEESfmwXn/n02VLsSkseYKNO+OuxzySSpPWGf45Xa7jELLtJLs+BS/PaTbPGYBddsP+YdbRFzpj3p97F1Ht+TdbSn2vfjc+65NnvEec6CS9q/NESS9uPgNQsu2Xuz4Nrb7Z/rSY9zqD1p/9i2O/vXlN7SrrNrcxf4+EbchbboZUeOHFFRUVHoZQAAPqHa2lqNGjXqvI/3uQJKJpM6evSoMjIyFIl0/c43Ho+rqKhItbW1yszMDLTC8DgOZ3EczuI4nMVxOKsvHAfnnE6dOqXCwkKlfMxPWPrcj+BSUlI+tjElKTMz85I+wT7AcTiL43AWx+EsjsNZoY9DLBa74Da8CAEAEAQFBAAIol8VUDQa1cqVKxWN+v0l04GC43AWx+EsjsNZHIez+tNx6HMvQgAAXBr61RUQAGDgoIAAAEFQQACAICggAEAQ/aaAVq9erU996lMaPHiwSkpK9Otf/zr0knrdd77zHUUikS63iRMnhl5Wj9u+fbtuuOEGFRYWKhKJ6Pnnn+/yuHNODz74oAoKCjRkyBCVlpbqwIEDYRbbgy50HBYvXvyR82PevHlhFttDKisrdfXVVysjI0O5ublasGCBqquru2zT0tKi8vJyjRgxQsOHD9fChQtVX18faMU942KOw6xZsz5yPtx5552BVnxu/aKAnnnmGa1YsUIrV67UW2+9palTp2ru3Lk6fvx46KX1uquuukrHjh3rvL3++uuhl9TjmpqaNHXqVK1evfqcj69atUo//vGP9fjjj2vXrl0aNmyY5s6dq5YW+yDJvuxCx0GS5s2b1+X8eOqpp3pxhT2vqqpK5eXl2rlzp1555RUlEgnNmTNHTU1Nndvcc889euGFF7RhwwZVVVXp6NGjuummmwKuuvtdzHGQpCVLlnQ5H1atWhVoxefh+oHp06e78vLyzrc7OjpcYWGhq6ysDLiq3rdy5Uo3derU0MsISpLbuHFj59vJZNLl5+e7hx9+uPO+hoYGF41G3VNPPRVghb3jw8fBOecWLVrk5s+fH2Q9oRw/ftxJclVVVc65sx/7tLQ0t2HDhs5tfve73zlJbseOHaGW2eM+fBycc+6LX/yi+/rXvx5uURehz18BtbW1ac+ePSotLe28LyUlRaWlpdqxY0fAlYVx4MABFRYWauzYsbr99tt1+PDh0EsKqqamRnV1dV3Oj1gsppKSkkvy/Ni2bZtyc3M1YcIELVu2TCdPngy9pB7V2NgoScrOzpYk7dmzR4lEosv5MHHiRI0ePXpAnw8fPg4fePLJJ5WTk6NJkyapoqJCzc3NIZZ3Xn1uGOmHnThxQh0dHcrLy+tyf15enn7/+98HWlUYJSUlWrdunSZMmKBjx47poYce0nXXXae3335bGRkZoZcXRF1dnSSd8/z44LFLxbx583TTTTepuLhYhw4d0re+9S2VlZVpx44dSk31+Fs9fVwymdTdd9+ta665RpMmTZJ09nxIT09XVlZWl20H8vlwruMgSbfddpvGjBmjwsJC7d+/X/fff7+qq6v13HPPBVxtV32+gPAfysrKOv89ZcoUlZSUaMyYMXr22Wf11a9+NeDK0Bfccsstnf+ePHmypkyZonHjxmnbtm2aPXt2wJX1jPLycr399tuXxPOgH+d8x2Hp0qWd/548ebIKCgo0e/ZsHTp0SOPGjevtZZ5Tn/8RXE5OjlJTUz/yKpb6+nrl5+cHWlXfkJWVpSuvvFIHDx4MvZRgPjgHOD8+auzYscrJyRmQ58fy5cv14osv6rXXXuvy51vy8/PV1tamhoaGLtsP1PPhfMfhXEpKSiSpT50Pfb6A0tPTNW3aNG3durXzvmQyqa1bt2rGjBkBVxbe6dOndejQIRUUFIReSjDFxcXKz8/vcn7E43Ht2rXrkj8/jhw5opMnTw6o88M5p+XLl2vjxo169dVXVVxc3OXxadOmKS0trcv5UF1drcOHDw+o8+FCx+Fc9u3bJ0l963wI/SqI
i/H000+7aDTq1q1b537729+6pUuXuqysLFdXVxd6ab3qG9/4htu2bZurqalxb7zxhistLXU5OTnu+PHjoZfWo06dOuX27t3r9u7d6yS5Rx55xO3du9e99957zjnnfvCDH7isrCy3adMmt3//fjd//nxXXFzszpw5E3jl3evjjsOpU6fcvffe63bs2OFqamrcli1b3Oc+9zl3xRVXuJaWltBL7zbLli1zsVjMbdu2zR07dqzz1tzc3LnNnXfe6UaPHu1effVVt3v3bjdjxgw3Y8aMgKvufhc6DgcPHnTf/e533e7du11NTY3btGmTGzt2rJs5c2bglXfVLwrIOecee+wxN3r0aJeenu6mT5/udu7cGXpJve7mm292BQUFLj093V1++eXu5ptvdgcPHgy9rB732muvOUkfuS1atMg5d/al2N/+9rddXl6ei0ajbvbs2a66ujrsonvAxx2H5uZmN2fOHDdy5EiXlpbmxowZ45YsWTLgvkk71/svya1du7ZzmzNnzrivfe1r7rLLLnNDhw51N954ozt27Fi4RfeACx2Hw4cPu5kzZ7rs7GwXjUbd+PHj3Te/+U3X2NgYduEfwp9jAAAE0eefAwIADEwUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACOL/AyBNQnoqGwl/AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "id": "56f36efb-e3a2-49f9-b9fb-1657bc25e5c5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-1.0776339769363403, -3.4281859397888184, 1.0321333408355713, -2.1151161193847656, 0.7665405869483948, 0.7089913487434387, 0.6775667071342468, 0.3138602077960968, 2.9969606399536133, 0.7927607297897339]\n",
      "predicted label: Bag\n"
     ]
    }
   ],
   "source": [
    "print(predictions)\n",
    "print(\"predicted label:\", classes[np.argmax(predictions)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "56ca1195-ea0f-405f-87fe-857e5c0c76a5",
   "metadata": {},
   "source": [
    "### 784 columns of float"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "id": "e0ab0af6-b5c9-4b74-9dd6-baa7737cc986",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "784"
      ]
     },
     "execution_count": 53,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = spark.read.parquet(data_path_784)\n",
    "len(df.columns)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "13ae45dc-85a0-4864-8a58-9dc29ae4efd7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 225 ms, sys: 91.1 ms, total: 316 ms\n",
      "Wall time: 3.16 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "0b3fb48b-f871-41f2-ac57-346899a6fe48",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 283 ms, sys: 67.8 ms, total: 351 ms\n",
      "Wall time: 1.47 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "id": "b59114ad",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 543 ms, sys: 65.1 ms, total: 608 ms\n",
      "Wall time: 1.36 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).collect()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "dc48ec42-0df6-4e6a-b019-1270ab71d2cf",
   "metadata": {
    "tags": []
   },
   "source": [
    "### Check predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "id": "d815c701-9f5b-422c-b3f9-fbc30456953c",
   "metadata": {},
   "outputs": [],
   "source": [
    "preds = df.withColumn(\"preds\", mnist(array(*df.columns))).limit(10).toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "id": "b571b742-5079-42b2-8524-9181a0dec2c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample = preds.iloc[0]\n",
    "predictions = sample.preds\n",
    "img = sample.drop('preds').to_numpy(dtype=float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "id": "d33d6a4e-e6b9-489d-ac21-c4eddc801784",
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "6d10061e-aca6-4f81-bdfe-72e327ed7349",
   "metadata": {},
   "outputs": [],
   "source": [
    "img = np.array(img).reshape(28,28)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "id": "01f70e08-2c1d-419f-8676-3f6f4aba760f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAHpRJREFUeJzt3X1wlPW99/HPZpNsAoSNIeSpBBpQoZWHnlJJuVWKJQOkczyiTMenP8BxYLTBKVKrk46K2s6kxRnr6FA8f7RQ71t8mhEYPb3pKJowtoEOKDeH0zZCTix4QoLS5oGEPJD9nT84bu+FAP1dbPLdhPdr5prJ7l7fXF+uvchnr+y134Scc04AAAyzNOsGAABXJgIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJtKtGzhXLBZTc3OzcnJyFAqFrNsBAHhyzqmzs1MlJSVKS7vweU7KBVBzc7NKS0ut2wAAXKZjx45p0qRJF3w85QIoJydHknSjvqN0ZRh3g4tJn/wl75qy/33Cu2b3f03zrsnKOONdI0lnBobnt9IDbnjO7mOxYP+ejPCAd01He7Z3zdypR71r2v+527vG9fZ61yC4M+rXB/pN/Of5hQxZAG3cuFHPPPOMWlpaNGfOHL3wwguaN2/eJeu++LVbujKUHiKAUll6WsS7JnOc/3MaHuO/nXBG2LtGktwwBZCGKYBCAQMoHCCA0vqyvGsyxmZ616SH+r1rXCjmXYPL8D8TRi/1NsqQ/G977bXXtG7dOq1fv14ffvih5syZoyVLlujECf9XvwCA0WlIAujZZ5/VqlWrdO+99+qrX/2qXnzxRY0ZM0a/+tWvhmJzAIARKOkB1NfXp/3796uiouLvG0lLU0VFherr689bv7e3Vx0dHQkLAGD0S3oAff755xoYGFBhYWHC/YWFhWppaTlv/ZqaGkWj0fjCFXAAcGUw/yBqdXW12tvb48uxY8esWwIADIOkXwWXn5+vcDis1tbWhPtbW1tVVFR03vqRSESRiP9VTgCAkS3pZ0CZmZmaO3eudu3aFb8vFotp165dmj9/frI3BwAYoYbkc0Dr1q3TihUr9I1vfEPz5s3Tc889p66uLt17771DsTkAwAg0JAF0xx136LPPPtMTTzyhlpYWfe1rX9POnTvPuzABAHDlGrJJCGvWrNGaNWuG6tsjBfz1Bv9RPG8Wv+ld80xmp3dNWeQz7xpJCsv/E/NpAT5lPzbNfzTMgPP/jXk44ASA/9c9xbtmz9/KvGv+dcq/edf8y+Lve9dkvfUH7xoMPfOr4AAAVyYCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmhmwYKUa/9qn+r1/29Ya9a5p7c71r0uS8ayQpppB3TU8sw7smmt7tXZMVOuNdE2RQqiQ1dk/0rjnWlutd89vu8/9I5aW0Xe3/Y8t/KxgOnAEBAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwwDRuB9RQNeNdMzzjtXVMSafOuyc/o9K6RpBP9471rMkL++6E/5v9fr9tFvGtywj3eNZJUnNXuXfONIv8J5LMyj3vX9OV4lyBFcQYEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABMNIEVhmQbd3TY/zH1g5LsBAzX4X9q6RpDFpfd41nQNZ3jXtA9neNZ/3jvOuuSn3Y+8aKdg+D4diAWr8j4dYxL8GqYkzIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRorAxmb3etc0D0S8a2LO/3VSv3fFWRmhgWGp6Y35/9crG/O5d03bwBjvmqBOnfF/bpvP5HjX9BUH
fXaRajgDAgCYIIAAACaSHkBPPvmkQqFQwjJjxoxkbwYAMMINyXtA1113nd59992/bySdt5oAAImGJBnS09NVVFQ0FN8aADBKDMl7QIcPH1ZJSYmmTp2qe+65R0ePHr3gur29vero6EhYAACjX9IDqLy8XFu2bNHOnTu1adMmNTU16aabblJnZ+eg69fU1CgajcaX0tLSZLcEAEhBSQ+gyspKffe739Xs2bO1ZMkS/eY3v1FbW5tef/31Qdevrq5We3t7fDl27FiyWwIApKAhvzogNzdX1157rY4cOTLo45FIRJGI/wfYAAAj25B/DujUqVNqbGxUcXHxUG8KADCCJD2AHn74YdXV1emTTz7R73//e912220Kh8O66667kr0pAMAIlvRfwX366ae66667dPLkSU2cOFE33nij9uzZo4kTJyZ7UwCAESzpAfTqq68m+1siReVE+rxrMhXzrkkL+ddkhYINrOx3/v8lpkZOeNdcndXiXdPcf5V3TXeA4a+SlJXmv/96YxneNR2xLO+azLH+xx1SE7PgAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmBjyP0iH0SsUct41Xc5/YGX7mTHeNUrv9q9RsCGmOeHT3jU//vifvWu+N7XWu+bj7iLvGknKDbD/ugYyA23LVyQSbNAsUg9nQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAE0zDRmA9Z/wPn4EAr3mae6PeNUFNzGr1rsnQgHdN9DtHvGtmNB33rtnTebV3jSS1BZhA3tkf8a4JMn08FuN182jBMwkAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEw0gR2Kke/+GTY0N93jUx5/866ZPTE7xrJOmfxnziXRMbptdxnw3keNeUZX8WaFt7/1bmXXOi27+/zJD/INeengzvGqQmzoAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYYBgpAuvu8h9GmhGKDUEn5xtwoUB1BeFO75r/c/J/BdhSr3fF7s4Z3jX/Ev3Qu0aS/q15lndNzxn/Hyc5aT3eNQOn+bE1WnAGBAAwQQABAEx4B9Du3bt1yy23qKSkRKFQSNu3b0943DmnJ554QsXFxcrOzlZFRYUOHz6crH4BAKOEdwB1dXVpzpw52rhx46CPb9iwQc8//7xefPFF7d27V2PHjtWSJUvU0+P/u14AwOjl/W5eZWWlKisrB33MOafnnntOjz32mG699VZJ0ksvvaTCwkJt375dd9555+V1CwAYNZL6HlBTU5NaWlpUUVERvy8ajaq8vFz19fWD1vT29qqjoyNhAQCMfkkNoJaWFklSYWFhwv2FhYXxx85VU1OjaDQaX0pLS5PZEgAgRZlfBVddXa329vb4cuzYMeuWAADDIKkBVFRUJElqbW1NuL+1tTX+2LkikYjGjx+fsAAARr+kBlBZWZmKioq0a9eu+H0dHR3au3ev5s+fn8xNAQBGOO+r4E6dOqUjR47Ebzc1NenAgQPKy8vT5MmTtXbtWv3kJz/RNddco7KyMj3++OMqKSnRsmXLktk3AGCE8w6gffv26eabb47fXrdunSRpxYoV2rJlix555BF1dXVp9erVamtr04033qidO3cqKysreV0DAEY87wBauHChnHMXfDwUCunpp5/W008/fVmNIfXFuvyHQmZoeIaRBjUzs9+7ZufHX/WumaaPvGvqT5R519yf94F3jST1x/x/O5+Vfsa/JjTgXRPqDnvXIDWZXwUHALgyEUAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBM+I8zBv5H
Wpf/VOKxacMzDftMLNjE5HFp/n82JKMhO9C2fB1rzvOuKbwuM9C2unv96/LGdnvXjAkwDTt8mtfNowXPJADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMMI0Vg4Z6Qd01WyL/m9ECGd01+5JR3TVDR/xyeAavZjRHvmvBi//0dVFrIeddkBGgvdMa/BqmJMyAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmGEaKwMK9/pMk+5z/wMpImv/0ySCDMYMa39QzLNsZ2zx8/6YxkT7vmrHp/jVh74pgQ3CRmjgDAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIJhpAgs/fTwbCcm/+GT0YDN9bsB75r0j//Lu8Z/K9L4T/yHfQY1Jedv3jV9Mf/Rohkh/+c2PHy7AUOMMyAAgAkCCABgwjuAdu/erVtuuUUlJSUKhULavn17wuMrV65UKBRKWJYuXZqsfgEAo4R3AHV1dWnOnDnauHHjBddZunSpjh8/Hl9eeeWVy2oSADD6eF+EUFlZqcrKyouuE4lEVFRUFLgpAMDoNyTvAdXW1qqgoEDTp0/XAw88oJMnT15w3d7eXnV0dCQsAIDRL+kBtHTpUr300kvatWuXfvazn6murk6VlZUaGBj8wtOamhpFo9H4UlpamuyWAAApKOmfA7rzzjvjX8+aNUuzZ8/WtGnTVFtbq0WLFp23fnV1tdatWxe/3dHRQQgBwBVgyC/Dnjp1qvLz83XkyJFBH49EIho/fnzCAgAY/YY8gD799FOdPHlSxcXFQ70pAMAI4v0ruFOnTiWczTQ1NenAgQPKy8tTXl6ennrqKS1fvlxFRUVqbGzUI488oquvvlpLlixJauMAgJHNO4D27dunm2++OX77i/dvVqxYoU2bNungwYP69a9/rba2NpWUlGjx4sX68Y9/rEgkkryuAQAjnncALVy4UM65Cz7+29/+9rIawsgRDjDvMyfNf2Bl74D/tTLjwj3eNUENfPbZsGwn65MLf5wh2fIyu71r/to3Zgg6OV8oNiybwTBgFhwAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwETS/yQ3rhzppy88Ff1CMuQ/DTuIsPx7S3UDx5qHbVsTMk951wSZhp0R8n8NHDrjXYIUxRkQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwwjRWCZnf4DP8OhkHfNV3OOe9ekhWLeNZKUERqeYalBuP6+YdvWVeld3jXTx7V612SF/H8EZXSNvkGzVyrOgAAAJgggAIAJAggAYIIAAgCYIIAAACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJhgGCkCS+/1H/iZFuA1TzR82rsmktbvXSNJ/W4gUF2q6o4F2w85aT3eNeH04RkSyjDS0YMzIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCYIIAAACYYRoqUlxE6410zNq030LZi8h+wmsr6FWxwZ5D91+/CgbblK6ObYaSjBWdAAAATBBAAwIRXANXU1Oj6669XTk6OCgoKtGzZMjU0NCSs09PTo6qqKk2YMEHjxo3T8uXL1dramtSmAQAjn1cA1dXVqaqqSnv27NE777yj/v5+LV68WF1dXfF1HnroIb311lt64403VFdXp+bmZt1+++1JbxwAMLJ5XYSwc+fOhNtbtmxRQUGB9u/frwULFqi9vV2//OUvtXXrVn3729+WJG3evFlf+cpXtGfPHn3zm99MXucAgBHtst4Dam9vlyTl
5eVJkvbv36/+/n5VVFTE15kxY4YmT56s+vr6Qb9Hb2+vOjo6EhYAwOgXOIBisZjWrl2rG264QTNnzpQktbS0KDMzU7m5uQnrFhYWqqWlZdDvU1NTo2g0Gl9KS0uDtgQAGEECB1BVVZUOHTqkV1999bIaqK6uVnt7e3w5duzYZX0/AMDIEOiDqGvWrNHbb7+t3bt3a9KkSfH7i4qK1NfXp7a2toSzoNbWVhUVFQ36vSKRiCKRSJA2AAAjmNcZkHNOa9as0bZt2/Tee++prKws4fG5c+cqIyNDu3btit/X0NCgo0ePav78+cnpGAAwKnidAVVVVWnr1q3asWOHcnJy4u/rRKNRZWdnKxqN6r777tO6deuUl5en8ePH68EHH9T8+fO5Ag4AkMArgDZt2iRJWrhwYcL9mzdv1sqVKyVJP//5z5WWlqbly5ert7dXS5Ys0S9+8YukNAsAGD28Asi5Sw8BzMrK0saNG7Vx48bATWFkSO8ensGdAwGulQkywFSSelywulTVGQs2uDMt5P/cZoQG/LcT4LnNbB9dz9GVjFlwAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATgf4iKiBJ4e7UnUocVrAp0H8d8J/onMo+G8gOVBd0/w2HzJOnvWuGZ247fHEGBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwATDSJHyYs7/dVJGKNig1H/vKwhUl6paBqKB6rLS+rxrwrGsQNvylfZZm3cNw0hTE2dAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADABAEEADBBAAEATDCMFIGFBpx3TX1veAg6SZ7/HGXDSBt6igPV/VP2J941aQFGfv7f7hzvGtfV5V2D1MQZEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMMI0VgsSz/waIT0k571xRmtHnXTE7/m3eNJLWcyQ1Ul6r+41SwYaSVOf/uXdPtIt41ueFu7xql82NrtOAMCABgggACAJjwCqCamhpdf/31ysnJUUFBgZYtW6aGhoaEdRYuXKhQKJSw3H///UltGgAw8nkFUF1dnaqqqrRnzx6988476u/v1+LFi9V1zh+IWrVqlY4fPx5fNmzYkNSmAQAjn9e7eTt37ky4vWXLFhUUFGj//v1asGBB/P4xY8aoqKgoOR0CAEaly3oPqL29XZKUl5eXcP/LL7+s/Px8zZw5U9XV1eruvvCVLr29vero6EhYAACjX+DrGWOxmNauXasbbrhBM2fOjN9/9913a8qUKSopKdHBgwf16KOPqqGhQW+++eag36empkZPPfVU0DYAACNU4ACqqqrSoUOH9MEHHyTcv3r16vjXs2bNUnFxsRYtWqTGxkZNmzbtvO9TXV2tdevWxW93dHSotLQ0aFsAgBEiUACtWbNGb7/9tnbv3q1JkyZddN3y8nJJ0pEjRwYNoEgkokjE/wNsAICRzSuAnHN68MEHtW3bNtXW1qqsrOySNQcOHJAkFRcH+0Q2AGB08gqgqqoqbd26VTt27FBOTo5aWlokSdFoVNnZ2WpsbNTWrVv1ne98RxMmTNDBgwf10EMPacGCBZo9e/aQ/AMAACOTVwBt2rRJ0tkPm/7/Nm/erJUrVyozM1PvvvuunnvuOXV1dam0tFTLly/XY489lrSGAQCjg/ev4C6mtLRUdXV1l9UQAODKwFhZBBaKXfwFyWCuy8z2rvnVyfMvXrmU5qyrvGsk6eWmed41efo40LaGQ3a4P1Ddr07e6F2TERrwrvnhxA8uvdK5LvFCGCMHw0gBACYIIACACQIIAGCCAAIAmCCAAAAmCCAAgAkCCABgggACAJgggAAAJgggAIAJAggAYIIAAgCY
CLlLjbgeZh0dHYpGo1qoW5UeyrBuBwDg6YzrV612qL29XePHj7/gepwBAQBMEEAAABMEEADABAEEADBBAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMBEunUD5/piNN0Z9UspNaUOAPCPOKN+SX//eX4hKRdAnZ2dkqQP9BvjTgAAl6Ozs1PRaPSCj6fcNOxYLKbm5mbl5OQoFAolPNbR0aHS0lIdO3bsohNWRzv2w1nsh7PYD2exH85Khf3gnFNnZ6dKSkqUlnbhd3pS7gwoLS1NkyZNuug648ePv6IPsC+wH85iP5zFfjiL/XCW9X642JnPF7gIAQBgggACAJgYUQEUiUS0fv16RSIR61ZMsR/OYj+cxX44i/1w1kjaDyl3EQIA4Mowos6AAACjBwEEADBBAAEATBBAAAATIyaANm7cqC9/+cvKyspSeXm5/vCHP1i3NOyefPJJhUKhhGXGjBnWbQ253bt365ZbblFJSYlCoZC2b9+e8LhzTk888YSKi4uVnZ2tiooKHT582KbZIXSp/bBy5crzjo+lS5faNDtEampqdP311ysnJ0cFBQVatmyZGhoaEtbp6elRVVWVJkyYoHHjxmn58uVqbW016nho/CP7YeHChecdD/fff79Rx4MbEQH02muvad26dVq/fr0+/PBDzZkzR0uWLNGJEyesWxt21113nY4fPx5fPvjgA+uWhlxXV5fmzJmjjRs3Dvr4hg0b9Pzzz+vFF1/U3r17NXbsWC1ZskQ9PT3D3OnQutR+kKSlS5cmHB+vvPLKMHY49Orq6lRVVaU9e/bonXfeUX9/vxYvXqyurq74Og899JDeeustvfHGG6qrq1Nzc7Nuv/12w66T7x/ZD5K0atWqhONhw4YNRh1fgBsB5s2b56qqquK3BwYGXElJiaupqTHsavitX7/ezZkzx7oNU5Lctm3b4rdjsZgrKipyzzzzTPy+trY2F4lE3CuvvGLQ4fA4dz8459yKFSvcrbfeatKPlRMnTjhJrq6uzjl39rnPyMhwb7zxRnydP/3pT06Sq6+vt2pzyJ27H5xz7lvf+pb7/ve/b9fUPyDlz4D6+vq0f/9+VVRUxO9LS0tTRUWF6uvrDTuzcfjwYZWUlGjq1Km65557dPToUeuWTDU1NamlpSXh+IhGoyovL78ij4/a2loVFBRo+vTpeuCBB3Ty5EnrloZUe3u7JCkvL0+StH//fvX39yccDzNmzNDkyZNH9fFw7n74wssvv6z8/HzNnDlT1dXV6u7utmjvglJuGOm5Pv/8cw0MDKiwsDDh/sLCQv35z3826spGeXm5tmzZounTp+v48eN66qmndNNNN+nQoUPKycmxbs9ES0uLJA16fHzx2JVi6dKluv3221VWVqbGxkb96Ec/UmVlperr6xUOh63bS7pYLKa1a9fqhhtu0MyZMyWdPR4yMzOVm5ubsO5oPh4G2w+SdPfdd2vKlCkqKSnRwYMH9eijj6qhoUFvvvmmYbeJUj6A8HeVlZXxr2fPnq3y8nJNmTJFr7/+uu677z7DzpAK7rzzzvjXs2bN0uzZszVt2jTV1tZq0aJFhp0NjaqqKh06dOiKeB/0Yi60H1avXh3/etasWSouLtaiRYvU2NioadOmDXebg0r5X8Hl5+crHA6fdxVLa2urioqKjLpKDbm5ubr22mt15MgR61bMfHEMcHycb+rUqcrPzx+Vx8eaNWv09ttv6/3330/48y1FRUXq6+tTW1tbwvqj9Xi40H4YTHl5uSSl1PGQ8gGUmZmpuXPnateuXfH7YrGYdu3apfnz5xt2Zu/UqVNqbGxUcXGxdStmysrKVFRUlHB8dHR0aO/evVf88fHpp5/q5MmTo+r4cM5pzZo12rZtm9577z2VlZUlPD537lxlZGQkHA8NDQ06evToqDoeLrUfBnPgwAFJSq3jwfoqiH/Eq6++6iKRiNuyZYv74x//6FavXu1yc3NdS0uLdWvD6gc/+IGrra11
TU1N7ne/+52rqKhw+fn57sSJE9atDanOzk730UcfuY8++shJcs8++6z76KOP3F/+8hfnnHM//elPXW5urtuxY4c7ePCgu/XWW11ZWZk7ffq0cefJdbH90NnZ6R5++GFXX1/vmpqa3Lvvvuu+/vWvu2uuucb19PRYt540DzzwgItGo662ttYdP348vnR3d8fXuf/++93kyZPde++95/bt2+fmz5/v5s+fb9h18l1qPxw5csQ9/fTTbt++fa6pqcnt2LHDTZ061S1YsMC480QjIoCcc+6FF15wkydPdpmZmW7evHluz5491i0NuzvuuMMVFxe7zMxM96Uvfcndcccd7siRI9ZtDbn333/fSTpvWbFihXPu7KXYjz/+uCssLHSRSMQtWrTINTQ02DY9BC62H7q7u93ixYvdxIkTXUZGhpsyZYpbtWrVqHuRNti/X5LbvHlzfJ3Tp0+7733ve+6qq65yY8aMcbfddps7fvy4XdND4FL74ejRo27BggUuLy/PRSIRd/XVV7sf/vCHrr293bbxc/DnGAAAJlL+PSAAwOhEAAEATBBAAAATBBAAwAQBBAAwQQABAEwQQAAAEwQQAMAEAQQAMEEAAQBMEEAAABMEEADAxH8DR+VYJGWV2nkAAAAASUVORK5CYII=",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plt.figure()\n",
    "plt.imshow(img)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "id": "8e1c07cc-b2bc-4902-a9a6-4ac7f02c5fe4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[ 2.44647     3.8623989   0.14587203  3.2146688   1.0799949  -2.5363288\n",
      "  0.86715794 -3.8287208  -2.02238    -2.9016623 ]\n",
      "predicted label: Trouser\n"
     ]
    }
   ],
   "source": [
    "print(predictions)\n",
    "print(\"predicted label:\", classes[np.argmax(predictions)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "3d47a8ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This will clear the engine cache (containing previously compiled TensorRT engines) and reset the CUDA context.\n",
    "torch._dynamo.reset()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "281c7889",
   "metadata": {},
   "source": [
    "## Using Triton Inference Server\n",
    "In this section, we demonstrate integration with the [Triton Inference Server](https://developer.nvidia.com/nvidia-triton-inference-server), an open-source, GPU-accelerated serving solution for deep learning.  \n",
    "We use [PyTriton](https://github.com/triton-inference-server/pytriton), a Flask-like framework that handles client/server communication with the Triton server.  \n",
    "\n",
    "The process looks like this:\n",
    "- Distribute a PyTriton task across the Spark cluster, instructing each node to launch a Triton server process.\n",
    "- Define a Triton inference function, which contains a client that binds to the local server on a given node and sends inference requests.\n",
    "- Wrap the Triton inference function in a predict_batch_udf to launch parallel inference requests using Spark.\n",
    "- Finally, distribute a shutdown signal to terminate the Triton server processes on each node.\n",
    "\n",
    "<img src=\"../images/spark-server.png\" alt=\"drawing\" width=\"700\"/>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "id": "53ca290a-ccc3-4923-a292-944921bab36d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d8abea75",
   "metadata": {},
   "source": [
    "Import the helper class from server_utils.py:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "id": "e616b207",
   "metadata": {},
   "outputs": [],
   "source": [
    "sc.addPyFile(\"server_utils.py\")\n",
    "\n",
    "from server_utils import TritonServerManager"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "606934ac",
   "metadata": {},
   "source": [
    "Define the Triton Server function:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "8fa92fe4-2e04-4d82-a357-bfdfca38bd8c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_server(ports, model_path):\n",
    "    import time\n",
    "    import signal\n",
    "    import numpy as np\n",
    "    import torch\n",
    "    from torch import nn\n",
    "    import torch_tensorrt as trt\n",
    "    from pytriton.decorators import batch\n",
    "    from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor\n",
    "    from pytriton.triton import Triton, TritonConfig\n",
    "    from pyspark import TaskContext\n",
    "\n",
    "    print(f\"SERVER: Initializing model on worker {TaskContext.get().partitionId()}.\")\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    \n",
    "    exp_program = torch.export.load(model_path)\n",
    "    example_inputs = (torch.randn((50, 784), dtype=torch.float).to(\"cuda\"),)\n",
    "    trt_gm = trt.dynamo.compile(exp_program,\n",
    "                                tuple(example_inputs),\n",
    "                                enabled_precisions={torch.float},\n",
    "                                workspace_size=1<<30)\n",
    "\n",
    "    print(\"SERVER: Compiled model.\")\n",
    "\n",
    "    @batch\n",
    "    def _infer_fn(**inputs):\n",
    "        images = inputs[\"images\"]\n",
    "        if len(images) != 1:\n",
    "            images = np.squeeze(images)\n",
    "        stream = torch.cuda.Stream()\n",
    "        with torch.no_grad(), torch.cuda.stream(stream):\n",
    "            torch_inputs = torch.from_numpy(images).to(device)\n",
    "            outputs = trt_gm(torch_inputs)\n",
    "            return {\n",
    "                \"labels\": outputs.cpu().numpy(),\n",
    "            }\n",
    "        \n",
    "    workspace_path = f\"/tmp/triton_{time.strftime('%m_%d_%M_%S')}\"\n",
    "    triton_conf = TritonConfig(http_port=ports[0], grpc_port=ports[1], metrics_port=ports[2])\n",
    "    with Triton(config=triton_conf, workspace=workspace_path) as triton:\n",
    "        triton.bind(\n",
    "            model_name=\"ImageClassifier\",\n",
    "            infer_func=_infer_fn,\n",
    "            inputs=[\n",
    "                Tensor(name=\"images\", dtype=np.float32, shape=(-1,)),\n",
    "            ],\n",
    "            outputs=[\n",
    "                Tensor(name=\"labels\", dtype=np.float32, shape=(-1,)),\n",
    "            ],\n",
    "            config=ModelConfig(\n",
    "                max_batch_size=64,\n",
    "                batcher=DynamicBatcher(max_queue_delay_microseconds=5000),  # 5ms\n",
    "            ),\n",
    "            strict=True,\n",
    "        )\n",
    "\n",
    "        def _stop_triton(signum, frame):\n",
    "            # The server manager sends SIGTERM to stop the server; this function ensures graceful cleanup.\n",
    "            print(\"SERVER: Received SIGTERM. Stopping Triton server.\")\n",
    "            triton.stop()\n",
    "\n",
    "        signal.signal(signal.SIGTERM, _stop_triton)\n",
    "\n",
    "        print(\"SERVER: Serving inference\")\n",
    "        triton.serve()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8fea6e5e",
   "metadata": {},
   "source": [
    "#### Start Triton servers  "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f837300c",
   "metadata": {},
   "source": [
    "The `TritonServerManager` will handle the lifecycle of Triton server instances across the Spark cluster:\n",
    "- Find available ports for HTTP/gRPC/metrics\n",
    "- Deploy a server on each node via stage-level scheduling\n",
    "- Gracefully shutdown servers across nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "id": "f72c53d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "model_name = \"ImageClassifier\"\n",
    "server_manager = TritonServerManager(model_name=model_name, model_path=model_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "65d3f7be",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-07 11:03:44,809 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-07 11:03:44,810 - INFO - Starting 1 servers.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'cb4ae00-lcedt': (2020631, [7000, 7001, 7002])}"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Returns {'hostname': (server_pid, [http_port, grpc_port, metrics_port])}\n",
    "server_manager.start_servers(triton_server)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "90ed191b",
   "metadata": {},
   "source": [
    "#### Define client function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "86c1545a",
   "metadata": {},
   "source": [
    "Get the hostname -> url mapping from the server manager:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c4c2833f",
   "metadata": {},
   "outputs": [],
   "source": [
    "host_to_http_url = server_manager.host_to_http_url  # or server_manager.host_to_grpc_url"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6771c93",
   "metadata": {},
   "source": [
    "Define the Triton inference function, which returns a predict function for batch inference through the server:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "id": "cec9a48c",
   "metadata": {},
   "outputs": [],
   "source": [
    "def triton_fn(model_name, host_to_url):\n",
    "    import socket\n",
    "    from pytriton.client import ModelClient\n",
    "\n",
    "    url = host_to_url[socket.gethostname()]\n",
    "    print(f\"Connecting to Triton model {model_name} at {url}.\")\n",
    "\n",
    "    def infer_batch(inputs):\n",
    "        with ModelClient(url, model_name, inference_timeout_s=240) as client:\n",
    "            result_data = client.infer_batch(inputs)\n",
    "            return result_data[\"labels\"]\n",
    "        \n",
    "    return infer_batch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "0262fd4a-9845-44b9-8c75-1c105e7deeca",
   "metadata": {},
   "outputs": [],
   "source": [
    "mnist = predict_batch_udf(partial(triton_fn, model_name=model_name, host_to_url=host_to_http_url),\n",
    "                          input_tensor_shapes=[[784]],\n",
    "                          return_type=ArrayType(FloatType()),\n",
    "                          batch_size=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30a4362d-7514-4b84-b238-f704a97e1e72",
   "metadata": {},
   "source": [
    "#### Run inference"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "ab94d4d1-dac6-4474-9eb0-59478aa98f7d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "StructType([StructField('data', ArrayType(FloatType(), True), True)])"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = spark.read.parquet(data_path_1)\n",
    "df.schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "fc5f6baa-052e-4b89-94b6-4821cf01952a",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 157 ms, sys: 47.6 ms, total: 205 ms\n",
      "Wall time: 2.49 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"preds\", mnist(struct(df.columns))).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "a85dea35-e41d-482d-8a8f-52d3c108f038",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 183 ms, sys: 60.3 ms, total: 243 ms\n",
      "Wall time: 1.49 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"preds\", mnist(*df.columns)).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "bc3f0dbe-c52b-41d6-8097-8cebaa5ee5a8",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 383 ms, sys: 43.9 ms, total: 427 ms\n",
      "Wall time: 1.6 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "preds = df.withColumn(\"preds\", mnist(*[col(c) for c in df.columns])).collect()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "99fb5e8d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Predicted label: Bag\n"
     ]
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAJJ5JREFUeJzt3Xt0lfWd7/HPTkg2t2THEHKTQAMotHLplErKqBRLDpDOuEA5HW9zDrg6MNLgqlKrJz1Watuz0uIa66lDca2zWqir4oVzREbHYgUljAp0QBjGXlLAKGEgoWCTDQlJdrJ/5w/GzERB+P5M8kvC+7XWXovs/Xx4fnnyJJ882TvfRJxzTgAA9LKU0AsAAFyaKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQQwKvYAPSyaTOnr0qDIyMhSJREIvBwBg5JzTqVOnVFhYqJSU81/n9LkCOnr0qIqKikIvAwDwCdXW1mrUqFHnfbzPFVBGRoYk6Vp9WYOUFng16Ha9dVXLhCkgmHYl9Lpe6vx6fj49VkCrV6/Www8/rLq6Ok2dOlWPPfaYpk+ffsHcBz92G6Q0DYpQQANOr/1YlQICgvn3T78LPY3SIy9CeOaZZ7RixQqtXLlSb731lqZOnaq5c+fq+PHjPbE7AEA/1CMF9Mgjj2jJkiW644479JnPfEaPP/64hg4dqp/97Gc9sTsAQD/U7QXU1tamPXv2qLS09D92kpKi0tJS7dix4yPbt7a2Kh6Pd7kBAAa+bi+gEydOqKOjQ3l5eV3uz8vLU11d3Ue2r6ysVCwW67zxCjgAuDQE/0XUiooKNTY2dt5qa2tDLwkA0Au6/VVwOTk5Sk1NVX19fZf76+vrlZ+f/5Hto9GootFody8DANDHdfsVUHp6uqZNm6atW7d23pdMJrV161bNmDGju3cHAOineuT3gFasWKFFixbp85//vKZPn65HH31UTU1NuuOOO3pidwCAfqhHCujmm2/WH//4Rz344IOqq6vTZz/7WW3evPkjL0wAAFy6Is71rZkl8XhcsVhMszSfSQgYsFLHF5szx/7O/lxp7vzfmzPAJ9XuEtqmTWpsbFRmZuZ5twv+KjgAwKWJAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEH0yDRsoL+q+YH9b1bdP3+jOVPVcP4BjeczPu2MOZO/356RpF9smG3OFH3vTa994dLFFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCYBp2HxYZZP/wuI4O+46cs2c8RdLSzRmXaDNnBhWPMWckacttD5szX/zV3ebMlX+z25ypNyek3Tdf75GSHvzuU+bM2u/5HXOzSMSe6cVzHBePKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLiXN+a0hePxxWLxTRL8zUokhZ6Od3HY4BiJDXVnPEaRuqrb506Xfxh7TSv3NiiP5ozg0oPe+2rLzvzcrE5c8Pl+82ZLZMyzBn0fe0uoW3apMbGRmVmZp53O66AAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACCIQaEXcMnwGNzp2tvt+/EYetqXh4pKUsrUT5szW770v732VfrLFebMlfIYRppiHzQbSbF/bL3OIUlpP8w2Z25e9y/mzIbF3zRnLlu3w5xB38QVEAAgCAoIABBEtxfQd77zHUUikS63iRMndvduAAD9XI88B3TVVVdpy5Yt/7GTQTzVBADoqkeaYdCgQcrPz++J/xoAMED0yHNABw4cUGFhocaOHavbb79dhw+f/1VCra2tisfjXW4AgIGv2wuopKRE
69at0+bNm7VmzRrV1NTouuuu06lTp865fWVlpWKxWOetqKiou5cEAOiDur2AysrK9JWvfEVTpkzR3Llz9dJLL6mhoUHPPvvsObevqKhQY2Nj5622tra7lwQA6IN6/NUBWVlZuvLKK3Xw4MFzPh6NRhWNRnt6GQCAPqbHfw/o9OnTOnTokAoKCnp6VwCAfqTbC+jee+9VVVWV3n33Xb355pu68cYblZqaqltvvbW7dwUA6Me6/UdwR44c0a233qqTJ09q5MiRuvbaa7Vz506NHDmyu3cFAOjHur2Ann766e7+L2HhM1jUYzCmJCnZYY7Eb/2COTNm+R/MmcdPXmfOSFLhq700ncolzZFIdKh9N57DSOu+YH9e9p1EpjnzzEMPmzMr//bL5kz9DH69oy9iFhwAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABNHjf5Cu10Qi9ozP4M7e5DMk1GPIpc9QUV9Xf2OPOVPdmGfONLenmzOSNPzZnV45q0iq5wDYXpJ2yp55s+kKc+bRP33KnLlr1BZz5oHblpgzkpS53uN86K2vRT778d1XD+EKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEEMnGnYfZ3H5Fqficku0XuTrYdtH2nOtDv7mOXUFPuE73c3jTVnJKlAdV45K9fh8XFqS3T/Qs4j77E3zZlvVVSbM9cevcqcWXlgvjlz8//cbM5I0ssvFJkzyVMeo8R9+E617kN/OYArIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYuAMI/UYlhdJS/fbVaLNI2Rfn9d+PBy978+9ct/MfdacefLfvmDOJGUfnljwiH2YZq/qw+eDr181p5kz/33MTnNm1d455kxqkd8wzYxf2r9GNF7rtaveE/G47nA9M+SYKyAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACCLinMdUxB4Uj8cVi8U0S/M1KGIfbjiQHF9uHxLammXfz6+WrrKHJK1r+Lw5Myr9fXPme7+8yZyJ/cE+wFSSrv+bXebMxt981pwZ9G9RcyalzeN9ivh9evvsq3VE0pwZOeGEOROLtpgzZ9r9vpasHP8P5syyDUvNmeL/scOc6cvaXULbtEmNjY3KzMw873ZcAQEAgqCAAABBmAto+/btuuGGG1RYWKhIJKLnn3++y+POOT344IMqKCjQkCFDVFpaqgMHDnTXegEAA4S5gJqamjR16lStXr36nI+vWrVKP/7xj/X4449r165dGjZsmObOnauWFvvPbQEAA5f5L6KWlZWprKzsnI855/Too4/qgQce0Pz58yVJTzzxhPLy8vT888/rlltu+WSrBQAMGN36HFBNTY3q6upUWlraeV8sFlNJSYl27Dj3qzxaW1sVj8e73AAAA1+3FlBdXZ0kKS8vr8v9eXl5nY99WGVlpWKxWOetqKioO5cEAOijgr8KrqKiQo2NjZ232tra0EsCAPSCbi2g/Px8SVJ9fX2X++vr6zsf+7BoNKrMzMwuNwDAwNetBVRcXKz8/Hxt3bq18754PK5du3ZpxowZ3bkrAEA/Z34V3OnTp3Xw4MHOt2tqarRv3z5lZ2dr9OjRuvvuu/X9739fV1xxhYqLi/Xtb39bhYWFWrBgQXeuGwDQz5kLaPfu3br++us7316xYoUkadGiRVq3bp3uu+8+NTU1aenSpWpoaNC1116rzZs3a/Dgwd23agBAv3dJDyN9Z5XfjwX/7safmzMr/vmvzJn09HZz5gdTnzNnNjdMMWckKUX2U+d3jXkX3uhD3tt7uTnTcVnCnJGkwbXp5kzmO/bjkNJuz3Sk2weEJs3fYp7l
Uu2ZZJp9fZGk/Ticntlszny26Ig5I0lHT8fMmWvy3jFn3nrf/urfI+9nmTOSNPor/+qVs2AYKQCgT6OAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAIz1m5A8Nvb/97r9zCg39hzuT8o/3PUZxeeMqc+dnR68yZxja/P5VxR9Eb5syJtmHmTM3gpDmjDvtkZklqu8y+r8RX/mTOjIo1mjMjo6fNmWiqfaK6JGUNsk+cTniM0G7qiJozMzOr7ftJ2vcjSW8OGm/OpMp+Dg0b1GbOvDR9jTkjSbfefq85E3typ9e+LoQrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIYsAMI225Ybo5kxbZ57evb+WbMzn/6z1z5uXxz5kzD5+wH4ehKfZBiJK08vUF5kxK3H7KuSyPgZoe80vP7ithzsQPXGbOHGwYYc7U2meeKrXV2UOeOqL2AbDOY2bs6+mfM2duX/yKfUeSrsv6gznzucGHzZmX064yZ/7in+80ZyTpbyp+Zc68/GSm174uhCsgAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAhiwAwjrf98770rWT+sNWf+MudfzJmfNtgHFP5V1j+bM9+t/UtzRpIu/2WqOZMY6jF9UmnmRCTpN4TTpdjX15Fu308yzb4+n7W1Zvkcb0kesYjHzFif/Qz/N/uk2cf/6Xr7jiT9Yf4ac+Y3bfZ3alFsvznzj5mTzRlJ+uvYv5ozv/qzpabtIx2t0r9suuB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhp2une29e9l282Zzb8abo5k5seN2e+eei/mjPNf3+5OSNJZ3Lt37/4DMdMbTNHlEz1G8KZ0ksDNX0Gdzr77FevtUlS+xC/nJXPcTg12n7eFVTZ9yNJX5s205wpHNxgzlSfzjNnFl6+15yRpBEp9g/umcuHmbZvT6RKFzF/mSsgAEAQFBAAIAhzAW3fvl033HCDCgsLFYlE9Pzzz3d5fPHixYpEIl1u8+bN6671AgAGCHMBNTU1aerUqVq9evV5t5k3b56OHTvWeXvqqac+0SIBAAOP+UUIZWVlKisr+9htotGo8vPzvRcFABj4euQ5oG3btik3N1cTJkzQsmXLdPLkyfNu29raqng83uUGABj4ur2A5s2bpyeeeEJbt27VD3/4Q1VVVamsrEwdHR3n3L6yslKxWKzzVlRU1N1LAgD0Qd3+e0C33HJL578nT56sKVOmaNy4cdq2bZtmz579ke0rKiq0YsWKzrfj8TglBACXgB5/GfbYsWOVk5OjgwcPnvPxaDSqzMzMLjcAwMDX4wV05MgRnTx5UgUFBT29KwBAP2L+Edzp06e7XM3U1NRo3759ys7OVnZ2th566CEtXLhQ+fn5OnTokO677z6NHz9ec+fO7daFAwD6N3MB7d69W9dff33n2x88f7No0SKtWbNG+/fv189//nM1NDSosLBQc+bM0fe+9z1Fo9HuWzUAoN8zF9CsWbPknDvv4y+//PInWpCvQc29t68DbfbfcXo43z448P+dtj8f1viz/2LOJGN+EysTw+y5yLlfDPmxOtLtGd8hnB9zap9/V700WNRrQKjncRh0xp5J8Rga63M++By7jqjfgfj1+qnmzIYVD5sz/5Q+zpy5esi75owkxZNJc2ZY9QnT9u0drRe1HbPgAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEES3/0nuUIacsE949TUu7bg5c6TdPl74vpfuNGcyRti/p0gMM0ckSelxe8Z5fMvT
MdiekcdUa0lqH+qxK4/pzD7r85m67Ssx3L7ApMdXk9Q2+5TqlIsbtNxF/FN+07AzDtuPw7cOzzdn/u+4LebMtjMeJ6ukwtRT5kzHgXds27vERW3HFRAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABDFghpFmHjptziRch9e+xqa1mDMzdywzZ/LfsA9C/NNEc8R7yGVLjj3TEbW/T9H3PQZJen5r5TMstX2o/X1qz7CfeynDL27A43+WbPWZlCpFj6aZM2mn7R8nr2PnMZw2tdVvGOmZXHvunfVXmDO/u/8fzBlpuEdGuixliDmTOmG8aXvX0SoduPB2XAEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBADZhhpSnObOZMW8RvU+K9tmebM8F/ZBwf+aaLHAMWkPdIxxD4QUpLasuw7i56wH/P0U/b1pS84bs5I0snGYeZMR5v90yj9cNScyd5u/34x4veh1Zls+7nXXOgxWNRjGKl85or6zSJVIsO+viEp9o/T/J13mjO/nPETc0aS2mU/9yIJ28TiSPLitucKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCGDDDSNXeYY40Js947epvd37NnEkbZZ+G2Jpjf59SWu37iST8JjWm5zebM2N/FDdn2t87Ys7UR0vMGUkav/WkPXS01p7JHWGOvHdTrjnTPMY2RPIDkaH24b7ujMdw36TH+dpuz3Sk+k1ldYPsueZR9kzGTvuw4rrpQ80ZSRqXZr/uaH/nXdv2LnFR23EFBAAIggICAARhKqDKykpdffXVysjIUG5urhYsWKDq6uou27S0tKi8vFwjRozQ8OHDtXDhQtXX13frogEA/Z+pgKqqqlReXq6dO3fqlVdeUSKR0Jw5c9TU1NS5zT333KMXXnhBGzZsUFVVlY4ePaqbbrqp2xcOAOjfTC9C2Lx5c5e3161bp9zcXO3Zs0czZ85UY2OjfvrTn2r9+vX60pe+JElau3atPv3pT2vnzp36whe+0H0rBwD0a5/oOaDGxkZJUnZ2tiRpz549SiQSKi0t7dxm4sSJGj16tHbs2HHO/6O1tVXxeLzLDQAw8HkXUDKZ1N13361rrrlGkyZNkiTV1dUpPT1dWVlZXbbNy8tTXV3dOf+fyspKxWKxzltRUZHvkgAA/Yh3AZWXl+vtt9/W008//YkWUFFRocbGxs5bba3H71QAAPodr19EXb58uV588UVt375do0aN6rw/Pz9fbW1tamho6HIVVF9fr/z8/HP+X9FoVNFo1GcZAIB+zHQF5JzT8uXLtXHjRr366qsqLi7u8vi0adOUlpamrVu3dt5XXV2tw4cPa8aMGd2zYgDAgGC6AiovL9f69eu1adMmZWRkdD6vE4vFNGTIEMViMX31q1/VihUrlJ2drczMTN11112aMWMGr4ADAHRhKqA1a9ZIkmbNmtXl/rVr12rx4sWSpB/96EdKSUnRwoUL1draqrlz5+onP/lJtywWADBwmArIuQsP2Rs8eLBWr16t1atXey/KS6r99RQvNo268Ebn4JIeGY95n6nNHkMDh3sMMI34vRYl2Wp/CrH5ypHmTHrNe+bM5c8eMmckqf4vxpozJ6/JMGeyRpw2Z1pP24fnRt5PN2ckSQ1p9n35zLT1OfV8Mn6zSBVJ2HfmhtsHwDYX2vfz11v+1pyRpJq//D/mTOqIbNP2LtkmvX/h7ZgFBwAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCC8/iJqX9TxuwPmzLCUVq99PXHNT82Zv04sNWcizanmTGosYc4kO/wmJrsW++kTX95ozqTcdaU509xqn+Ys
Sc6dsofeH2KONL6bZc5E7IPOpTTPMdD2U89r4rRL9Qh5TKOP+Iyjl+QG2w96aoP986JjmP2dSq/vvS/fLX9WfOGN/pP29hbptQtvxxUQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAAQxYIaR+shKafbKpXpMXfzRdU+bM4WD/mTO/OLkn5szL74xzZyRJHXYBzy+fyTLnElttn+f5HrxW6uIx0BNl2YfPun8ZsZ6ibR7DO/0mffpMyvVYz8uxXMoq8c5nox67ssoJeE3YLU52WbOtGXZqqI9cXHbcwUEAAiCAgIABEEBAQCCoIAAAEFQQACAICggAEAQFBAAIAgKCAAQBAUEAAiCAgIABEEBAQCCoIAAAEFc0sNIt5y6yis3Zehhc+Zyj8GiE9LazZlPDT5pzvy3Wf9kzkjS+t9cbQ8dHmKOpNgPgxIx+7BPSYp4DJ+MJO2ZlDN+gyStXO/sRpIU8ZjB6VLsC/QZNOuztrP78nmnfM4h+24GtdgzklTT3mHODD6RMG3f3n5x23MFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBXNrDSI9O8MrFRjebM+PS/mjO/CJ+pTmTFrEPGizL2G/OSNKUz9eaMyNKTpszean2zMb4n5kzkvTSUfuA2vak/fu45rY0cybpsZ9omm2I5AeGeAzC9TkOg1LsUzh95oqebol6pKTYEPvEz8LhjeZMW0eqOZP0mcoq6Y0z48yZYzMGm7bvaJV0ETOOuQICAARBAQEAgjAVUGVlpa6++mplZGQoNzdXCxYsUHV1dZdtZs2apUgk0uV25513duuiAQD9n6mAqqqqVF5erp07d+qVV15RIpHQnDlz1NTU1GW7JUuW6NixY523VatWdeuiAQD9n+lFCJs3b+7y9rp165Sbm6s9e/Zo5syZnfcPHTpU+fn53bNCAMCA9ImeA2psPPtqj+zs7C73P/nkk8rJydGkSZNUUVGh5ubzv2qstbVV8Xi8yw0AMPB5vww7mUzq7rvv1jXXXKNJkyZ13n/bbbdpzJgxKiws1P79+3X//ferurpazz333Dn/n8rKSj300EO+ywAA9FPeBVReXq63335br7/+epf7ly5d2vnvyZMnq6CgQLNnz9ahQ4c0btxHX39eUVGhFStWdL4dj8dVVFTkuywAQD/hVUDLly/Xiy++qO3bt2vUqFEfu21JSYkk6eDBg+csoGg0qmjU75fEAAD9l6mAnHO66667tHHjRm3btk3FxcUXzOzbt0+SVFBQ4LVAAMDAZCqg8vJyrV+/Xps2bVJGRobq6uokSbFYTEOGDNGhQ4e0fv16ffnLX9aIESO0f/9+3XPPPZo5c6amTJnSI+8AAKB/MhXQmjVrJJ39ZdP/bO3atVq8eLHS09O1ZcsWPfroo2pqalJRUZEWLlyoBx54oNsWDAAYGMw/gvs4RUVFqqqq+kQLAgBcGi7padhDPCcFL8v6jTlzMBExZ8qz7NOm/dgn8Z5l/52t5mSbOTM0Zag58+mc6gtvdA7LLttrzlyWal+fjzda7JOjG5J+a8tPtX9sB3tMYu+Q/fOixdnP15EpreaMJBWnDTdn9rTaz/HxafZjt6Mly5yRpPV/LDFnRlW+adq+3SV04CK2YxgpACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAAgCAoIABAEBQQACAICggAEAQFBAAIggICAARxaQ8j/brfX2Itm/B1cybaYB98mkzrne8POqJ++zlyvT0XybMPhcx4c4g5U/jLo+aMJLlU+/vUcdkwcyb1VIs5o2PHzRGXaLfvR1JkqH2IaWS4x+DTC0zYP6d2++BOX81X2f+Qps/n7fA9h82Z9mN15sxZ9kGzPYUrIABAEBQQACAICggAEAQFBAAIggICAARBAQEAgqCAAABBUEAA
gCAoIABAEBQQACAICggAEESfmwXn/n02VLsSkseYKNO+OuxzySSpPWGf45Xa7jELLtJLs+BS/PaTbPGYBddsP+YdbRFzpj3p97F1Ht+TdbSn2vfjc+65NnvEec6CS9q/NESS9uPgNQsu2Xuz4Nrb7Z/rSY9zqD1p/9i2O/vXlN7SrrNrcxf4+EbchbboZUeOHFFRUVHoZQAAPqHa2lqNGjXqvI/3uQJKJpM6evSoMjIyFIl0/c43Ho+rqKhItbW1yszMDLTC8DgOZ3EczuI4nMVxOKsvHAfnnE6dOqXCwkKlfMxPWPrcj+BSUlI+tjElKTMz85I+wT7AcTiL43AWx+EsjsNZoY9DLBa74Da8CAEAEAQFBAAIol8VUDQa1cqVKxWN+v0l04GC43AWx+EsjsNZHIez+tNx6HMvQgAAXBr61RUQAGDgoIAAAEFQQACAICggAEAQ/aaAVq9erU996lMaPHiwSkpK9Otf/zr0knrdd77zHUUikS63iRMnhl5Wj9u+fbtuuOEGFRYWKhKJ6Pnnn+/yuHNODz74oAoKCjRkyBCVlpbqwIEDYRbbgy50HBYvXvyR82PevHlhFttDKisrdfXVVysjI0O5ublasGCBqquru2zT0tKi8vJyjRgxQsOHD9fChQtVX18faMU942KOw6xZsz5yPtx5552BVnxu/aKAnnnmGa1YsUIrV67UW2+9palTp2ru3Lk6fvx46KX1uquuukrHjh3rvL3++uuhl9TjmpqaNHXqVK1evfqcj69atUo//vGP9fjjj2vXrl0aNmyY5s6dq5YW+yDJvuxCx0GS5s2b1+X8eOqpp3pxhT2vqqpK5eXl2rlzp1555RUlEgnNmTNHTU1Nndvcc889euGFF7RhwwZVVVXp6NGjuummmwKuuvtdzHGQpCVLlnQ5H1atWhVoxefh+oHp06e78vLyzrc7OjpcYWGhq6ysDLiq3rdy5Uo3derU0MsISpLbuHFj59vJZNLl5+e7hx9+uPO+hoYGF41G3VNPPRVghb3jw8fBOecWLVrk5s+fH2Q9oRw/ftxJclVVVc65sx/7tLQ0t2HDhs5tfve73zlJbseOHaGW2eM+fBycc+6LX/yi+/rXvx5uURehz18BtbW1ac+ePSotLe28LyUlRaWlpdqxY0fAlYVx4MABFRYWauzYsbr99tt1+PDh0EsKqqamRnV1dV3Oj1gsppKSkkvy/Ni2bZtyc3M1YcIELVu2TCdPngy9pB7V2NgoScrOzpYk7dmzR4lEosv5MHHiRI0ePXpAnw8fPg4fePLJJ5WTk6NJkyapoqJCzc3NIZZ3Xn1uGOmHnThxQh0dHcrLy+tyf15enn7/+98HWlUYJSUlWrdunSZMmKBjx47poYce0nXXXae3335bGRkZoZcXRF1dnSSd8/z44LFLxbx583TTTTepuLhYhw4d0re+9S2VlZVpx44dSk31+Fs9fVwymdTdd9+ta665RpMmTZJ09nxIT09XVlZWl20H8vlwruMgSbfddpvGjBmjwsJC7d+/X/fff7+qq6v13HPPBVxtV32+gPAfysrKOv89ZcoUlZSUaMyYMXr22Wf11a9+NeDK0Bfccsstnf+ePHmypkyZonHjxmnbtm2aPXt2wJX1jPLycr399tuXxPOgH+d8x2Hp0qWd/548ebIKCgo0e/ZsHTp0SOPGjevtZZ5Tn/8RXE5OjlJTUz/yKpb6+nrl5+cHWlXfkJWVpSuvvFIHDx4MvZRgPjgHOD8+auzYscrJyRmQ58fy5cv14osv6rXXXuvy51vy8/PV1tamhoaGLtsP1PPhfMfhXEpKSiSpT50Pfb6A0tPTNW3aNG3durXzvmQyqa1bt2rGjBkBVxbe6dOndejQIRUUFIReSjDFxcXKz8/vcn7E43Ht2rXrkj8/jhw5opMnTw6o88M5p+XLl2vjxo169dVXVVxc3OXxadOmKS0trcv5UF1drcOHDw+o8+FCx+Fc9u3bJ0l963wI/SqI
i/H000+7aDTq1q1b537729+6pUuXuqysLFdXVxd6ab3qG9/4htu2bZurqalxb7zxhistLXU5OTnu+PHjoZfWo06dOuX27t3r9u7d6yS5Rx55xO3du9e99957zjnnfvCDH7isrCy3adMmt3//fjd//nxXXFzszpw5E3jl3evjjsOpU6fcvffe63bs2OFqamrcli1b3Oc+9zl3xRVXuJaWltBL7zbLli1zsVjMbdu2zR07dqzz1tzc3LnNnXfe6UaPHu1effVVt3v3bjdjxgw3Y8aMgKvufhc6DgcPHnTf/e533e7du11NTY3btGmTGzt2rJs5c2bglXfVLwrIOecee+wxN3r0aJeenu6mT5/udu7cGXpJve7mm292BQUFLj093V1++eXu5ptvdgcPHgy9rB732muvOUkfuS1atMg5d/al2N/+9rddXl6ei0ajbvbs2a66ujrsonvAxx2H5uZmN2fOHDdy5EiXlpbmxowZ45YsWTLgvkk71/svya1du7ZzmzNnzrivfe1r7rLLLnNDhw51N954ozt27Fi4RfeACx2Hw4cPu5kzZ7rs7GwXjUbd+PHj3Te/+U3X2NgYduEfwp9jAAAE0eefAwIADEwUEAAgCAoIABAEBQQACIICAgAEQQEBAIKggAAAQVBAAIAgKCAAQBAUEAAgCAoIABAEBQQACOL/AyBNQnoqGwl/AAAAAElFTkSuQmCC",
      "text/plain": [
       "<Figure size 640x480 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Sample prediction\n",
    "# Each prediction row carries the original input pixels (`data`) alongside\n",
    "# the model's output scores (`preds`) — presumably one score per class; TODO confirm against the predict UDF.\n",
    "sample = preds[0]\n",
    "predictions = sample.preds\n",
    "img = sample.data\n",
    "\n",
    "# The image was flattened for inference; restore the 28x28 FashionMNIST shape for display.\n",
    "img = np.array(img).reshape(28,28)\n",
    "plt.figure()\n",
    "plt.imshow(img)\n",
    "\n",
    "# argmax over the score vector picks the predicted class index, mapped to its label name.\n",
    "print(\"Predicted label:\", classes[np.argmax(predictions)])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7a26690a-9dc4-4c36-9904-568d73e2be3c",
   "metadata": {
    "tags": []
   },
   "source": [
    "#### Stop Triton Server on each executor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e02838ba",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2025-02-04 14:00:18,330 - INFO - Requesting stage-level resources: (cores=5, gpu=1.0)\n",
      "2025-02-04 14:00:28,520 - INFO - Sucessfully stopped 1 servers.                 \n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[True]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Shut down the Triton inference server started on each Spark executor\n",
    "# (returns a list of per-server success flags, shown below).\n",
    "server_manager.stop_servers()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "a0608fff-7cfb-489e-96c9-8e1d92e57562",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Skip spark.stop() on Databricks — there it puts the cluster in a bad state.\n",
    "if not on_databricks:\n",
    "    spark.stop()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "08de2664-3d60-487b-90da-6d0f3b8b9203",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "spark-dl-torch",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
